Skip to content

Commit

Permalink
Merge pull request #1388 from helmholtz-analytics/features/900-Improv…
Browse files Browse the repository at this point in the history
…e_load-functionality_load_multiple_files_into_one_DNDarray

Load functionality for multiple .npy files
  • Loading branch information
mrfh92 authored Jul 4, 2024
2 parents 225dc96 + b5a06b2 commit a774559
Show file tree
Hide file tree
Showing 2 changed files with 131 additions and 1 deletion.
64 changes: 63 additions & 1 deletion heat/core/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import numpy as np
import torch
import warnings
import fnmatch

from typing import Dict, Iterable, List, Optional, Tuple, Union

Expand All @@ -27,7 +28,15 @@
__NETCDF_EXTENSIONS = frozenset([".nc", ".nc4", "netcdf"])
__NETCDF_DIM_TEMPLATE = "{}_dim_{}"

__all__ = ["load", "load_csv", "save_csv", "save", "supports_hdf5", "supports_netcdf"]
__all__ = [
"load",
"load_csv",
"save_csv",
"save",
"supports_hdf5",
"supports_netcdf",
"load_npy_from_path",
]

try:
import netCDF4 as nc
Expand Down Expand Up @@ -1131,3 +1140,56 @@ def save(

DNDarray.save = lambda self, path, *args, **kwargs: save(self, path, *args, **kwargs)
DNDarray.save.__doc__ = save.__doc__


def load_npy_from_path(
    path: str,
    dtype: datatype = types.int32,
    split: int = 0,
    device: Optional[str] = None,
    comm: Optional[Communication] = None,
) -> DNDarray:
    """
    Loads multiple .npy files into one DNDarray which will be returned. The data will be concatenated along the split axis provided as input.

    Parameters
    ----------
    path : str
        Path to the directory in which .npy-files are located.
    dtype : datatype, optional
        Data type of the resulting array.
    split : int
        Along which axis the loaded arrays should be concatenated.
    device : str, optional
        The device id on which to place the data, defaults to globally set default device.
    comm : Communication, optional
        The communication to use for the data distribution, default is 'heat.MPI_WORLD'

    Raises
    ------
    TypeError
        If ``path`` is not a str or ``split`` is neither None nor int.
    ValueError
        If no .npy files are found under ``path``.
    RuntimeError
        If there are more processes than .npy files.
    """
    if not isinstance(path, str):
        raise TypeError(f"path must be str, not {type(path)}")
    elif split is not None and not isinstance(split, int):
        raise TypeError(f"split must be None or int, not {type(split)}")

    # Honor the comm argument; fall back to the documented default MPI_WORLD.
    if comm is None:
        comm = MPI_WORLD
    process_number = comm.size

    # Sort the matches: os.listdir order is arbitrary and may differ between
    # processes, which would scramble the global concatenation order.
    file_list = sorted(f for f in os.listdir(path) if fnmatch.fnmatch(f, "*.npy"))
    n_files = len(file_list)

    if n_files == 0:
        raise ValueError("No .npy Files were found")
    if (n_files < process_number) and (process_number > 1):
        raise RuntimeError("Number of processes can't exceed number of files")

    rank = comm.rank
    n_for_procs = n_files // process_number
    idx = rank * n_for_procs
    # The last rank additionally takes the remainder files.
    if rank + 1 == process_number:
        n_for_procs += n_files % process_number
    array_list = [
        np.load(os.path.join(path, element)) for element in file_list[idx : idx + n_for_procs]
    ]
    # Local concatenation along the split axis; the result is the local shard.
    larray = np.concatenate(array_list, split)
    larray = torch.from_numpy(larray)

    return factories.array(larray, dtype=dtype, device=device, is_split=split, comm=comm)
68 changes: 68 additions & 0 deletions heat/core/tests/test_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@
import os
import torch
import tempfile
import random
import time
import fnmatch

import heat as ht
from .test_suites.basic_test import TestCase
Expand Down Expand Up @@ -739,3 +742,68 @@ def test_save_netcdf_exception(self):
# os.rmdir(os.getcwd() + '/tmp/')
# except OSError:
# pass

def test_load_npy_int(self):
    # int arrays: rank 0 writes 20 random .npy files, every rank loads them
    # distributed along split=0 and the result is compared to the originals.
    if ht.MPI_WORLD.rank == 0:
        crea_array = []
        for i in range(0, 20):
            x = np.random.randint(1000, size=(random.randint(0, 30), 6, 11))
            np.save(os.path.join(os.getcwd(), "heat/datasets", "int_data") + str(i), x)
            crea_array.append(x)
        int_array = np.concatenate(crea_array)
    ht.MPI_WORLD.Barrier()

    load_array = ht.load_npy_from_path(
        os.path.join(os.getcwd(), "heat/datasets"), dtype=ht.int32, split=0
    )
    load_array_npy = load_array.numpy()

    self.assertIsInstance(load_array, ht.DNDarray)
    self.assertEqual(load_array.dtype, ht.int32)
    if ht.MPI_WORLD.rank == 0:
        self.assertEqual(load_array_npy.shape, int_array.shape)
        # NOTE: `.all` without the call parentheses is a bound method and is
        # always truthy, so the old assertion could never fail. File
        # enumeration order is not guaranteed by the loader, so compare the
        # contents order-insensitively (as multisets of elements).
        self.assertTrue(
            np.array_equal(
                np.sort(load_array_npy, axis=None), np.sort(int_array, axis=None)
            )
        )
        for file in os.listdir(os.path.join(os.getcwd(), "heat/datasets")):
            if fnmatch.fnmatch(file, "*.npy"):
                os.remove(os.path.join(os.getcwd(), "heat/datasets", file))

def test_load_npy_float(self):
    # float arrays with a split dimension other than 0: rank 0 writes the
    # files, every rank loads them concatenated along axis 1.
    if ht.MPI_WORLD.rank == 0:
        crea_array = []
        for i in range(0, 20):
            x = np.random.rand(2, random.randint(1, 10), 11)
            np.save(os.path.join(os.getcwd(), "heat/datasets", "float_data") + str(i), x)
            crea_array.append(x)
        float_array = np.concatenate(crea_array, 1)
    ht.MPI_WORLD.Barrier()

    load_array = ht.load_npy_from_path(
        os.path.join(os.getcwd(), "heat/datasets"), dtype=ht.float64, split=1
    )
    load_array_npy = load_array.numpy()
    self.assertIsInstance(load_array, ht.DNDarray)
    self.assertEqual(load_array.dtype, ht.float64)
    if ht.MPI_WORLD.rank == 0:
        self.assertEqual(load_array_npy.shape, float_array.shape)
        # NOTE: `.all` without the call parentheses is a bound method and is
        # always truthy, so the old assertion could never fail. File
        # enumeration order is not guaranteed by the loader, so compare the
        # contents order-insensitively (as multisets of elements).
        self.assertTrue(
            np.array_equal(
                np.sort(load_array_npy, axis=None), np.sort(float_array, axis=None)
            )
        )
        for file in os.listdir(os.path.join(os.getcwd(), "heat/datasets")):
            if fnmatch.fnmatch(file, "*.npy"):
                os.remove(os.path.join(os.getcwd(), "heat/datasets", file))

def test_load_npy_exception(self):
    # Exercise the error paths of load_npy_from_path.
    world = ht.MPI_WORLD

    # Invalid argument types.
    with self.assertRaises(TypeError):
        ht.load_npy_from_path(path=1, split=0)
    with self.assertRaises(TypeError):
        ht.load_npy_from_path("heat/datasets", split="ABC")

    # A directory containing no .npy files.
    with self.assertRaises(ValueError):
        ht.load_npy_from_path(path="heat", dtype=ht.int64, split=0)

    # Fewer files than processes: only meaningful with more than one rank.
    if world.size > 1:
        if world.rank == 0:
            sample = np.random.rand(2, random.randint(1, 10), 11)
            np.save(os.path.join(os.getcwd(), "heat/datasets", "float_data"), sample)
        world.Barrier()  # make sure the single file exists on all ranks
        with self.assertRaises(RuntimeError):
            ht.load_npy_from_path("heat/datasets", dtype=ht.int64, split=0)
        world.Barrier()  # everybody is done loading before cleanup
        if world.rank == 0:
            os.remove(os.path.join(os.getcwd(), "heat/datasets", "float_data.npy"))

0 comments on commit a774559

Please sign in to comment.