Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement sketched percentile #1420

Merged
merged 30 commits into from
Jul 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
60fb1de
added sketched version of percentile computation
Apr 3, 2024
fadbcc0
added tests for sketched percentile
Apr 3, 2024
cdea1d0
added sketching to median
Apr 3, 2024
1db9e3d
added support for sketched median/percentile to preprocessing and add…
Apr 3, 2024
c36b514
removed extensive checks from auxiliary function _create_sketch
Apr 4, 2024
32da442
Merge branch 'main' into features/1411-Implement_sketched_percentile
mrfh92 Apr 4, 2024
18506b7
Update heat/core/tests/test_statistics.py
mrfh92 Apr 5, 2024
391db33
Update statistics.py
mrfh92 Apr 5, 2024
efc3293
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 5, 2024
f7f2104
Update preprocessing.py
mrfh92 Apr 5, 2024
dde023f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 5, 2024
659c7a1
Update statistics.py
mrfh92 Apr 5, 2024
a1aef0b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 5, 2024
3d0b9bb
Update preprocessing.py
mrfh92 Apr 5, 2024
342d0f5
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 5, 2024
3195de6
Merge branch 'main' into features/1411-Implement_sketched_percentile
mrfh92 Apr 23, 2024
fd0c4a9
Merge branch 'main' into features/1411-Implement_sketched_percentile
mrfh92 May 31, 2024
bbc452e
Merge branch 'main' into features/1411-Implement_sketched_percentile
mrfh92 Jun 5, 2024
05a6da7
some of review requests addressed
Jun 13, 2024
a89be1c
took review into account
Jun 17, 2024
872f9de
addressed all change requests
Jun 17, 2024
b11f026
added fix for single process setting
Jun 17, 2024
cf72321
Merge branch 'main' into features/1411-Implement_sketched_percentile
mrfh92 Jun 17, 2024
7822c87
Merge branch 'main' into features/1411-Implement_sketched_percentile
mrfh92 Jun 20, 2024
d9723e9
Merge branch 'main' into features/1411-Implement_sketched_percentile
ClaudiaComito Jun 21, 2024
0109908
Update heat/core/statistics.py
mrfh92 Jun 24, 2024
bb4bed2
Update heat/core/statistics.py
mrfh92 Jun 24, 2024
330d451
Merge branch 'main' into features/1411-Implement_sketched_percentile
mrfh92 Jul 1, 2024
f63b7b5
bugfix on torch 1.12/13: some indices where of type int32 instead of …
Jul 1, 2024
896f573
Merge branch 'features/1411-Implement_sketched_percentile' of github.…
Jul 1, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 92 additions & 5 deletions heat/core/statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
from . import stride_tricks
from . import logical
from . import constants
from .random import randint
from warnings import warn

__all__ = [
"argmax",
Expand Down Expand Up @@ -1016,10 +1018,19 @@ def reduce_means_elementwise(output_shape_i: torch.Tensor) -> DNDarray:
DNDarray.mean.__doc__ = mean.__doc__


def median(x: DNDarray, axis: Optional[int] = None, keepdims: bool = False) -> DNDarray:
def median(
x: DNDarray,
axis: Optional[int] = None,
keepdims: bool = False,
sketched: bool = False,
sketch_size: Optional[float] = 1.0 / MPI.COMM_WORLD.size,
) -> DNDarray:
"""
Compute the median of the data along the specified axis.
Returns the median of the ``DNDarray`` elements.
Per default, the "true" median of the entire data set is computed; however, the argument
`sketched` allows switching to a faster but less accurate version that computes
the median only on a random subset of the data set ("sketch").

Parameters
----------
Expand All @@ -1032,14 +1043,25 @@ def median(x: DNDarray, axis: Optional[int] = None, keepdims: bool = False) -> D
keepdims : bool, optional
If True, the axes which are reduced are left in the result as dimensions with size one.
With this option, the result can broadcast correctly against the original array ``x``.

sketched : bool, optional
If True, the median is computed on a random subset of the data set ("sketch").
This is faster but less accurate. Default is False. The size of the sketch is controlled by the argument `sketch_size`.
sketch_size : float, optional
The size of the sketch as a fraction of the data set size. Default is `1./n_proc` where `n_proc` is the number of MPI processes, e.g. `n_proc = MPI.COMM_WORLD.size`. Must be in the range (0, 1).
Ignored for sketched = False.
"""
return percentile(x, q=50, axis=axis, keepdims=keepdims)
return percentile(
x, q=50, axis=axis, keepdims=keepdims, sketched=sketched, sketch_size=sketch_size
)


DNDarray.median: Callable[[DNDarray, int, bool], DNDarray] = (
lambda x, axis=None, keepdims=False: median(x, axis, keepdims)
DNDarray.median: Callable[[DNDarray, int, bool, bool, float], DNDarray] = (
lambda x, axis=None, keepdims=False, sketched=False, sketch_size=1.0 / MPI.COMM_WORLD.size: median(
x, axis, keepdims, sketched=sketched, sketch_size=sketch_size
)
)
DNDarray.mean.__doc__ = mean.__doc__
DNDarray.median.__doc__ = median.__doc__


def __merge_moments(
Expand Down Expand Up @@ -1412,10 +1434,15 @@ def percentile(
out: Optional[DNDarray] = None,
interpolation: str = "linear",
keepdims: bool = False,
sketched: bool = False,
sketch_size: Optional[float] = 1.0 / MPI.COMM_WORLD.size,
) -> DNDarray:
r"""
Compute the q-th percentile of the data along the specified axis.
Returns the q-th percentile(s) of the tensor elements.
Per default, the "true" percentile(s) of the entire data set are computed; however, the argument
`sketched` allows switching to a faster but less accurate version that computes
the percentile(s) only on a random subset of the data set ("sketch").

Parameters
----------
Expand Down Expand Up @@ -1447,6 +1474,14 @@ def percentile(
keepdims : bool, optional
If True, the axes which are reduced are left in the result as dimensions with size one.
With this option, the result can broadcast correctly against the original array x.

sketched : bool, optional
If False (default), the entire data is used and no sketching is performed.
If True, only a fraction of the data is used for estimating the percentile. The fraction is determined by `sketch_size`.
sketch_size : float, optional
The fraction of the data to use for estimating the percentile; needs to be strictly between 0 and 1.
The default is 1/size of the MPI communicator, i.e., roughly the portion of the data that is anyway processed on a single process.
Ignored for sketched = False.
"""

def _local_percentile(data: torch.Tensor, axis: int, indices: torch.Tensor) -> torch.Tensor:
Expand Down Expand Up @@ -1490,13 +1525,65 @@ def _local_percentile(data: torch.Tensor, axis: int, indices: torch.Tensor) -> t

return percentile

def _create_sketch(
    a: DNDarray,
    axis: Union[int, None],
    sketch_size_relative: Optional[float] = None,
    sketch_size_absolute: Optional[int] = None,
) -> DNDarray:
    """
    Create a sketch of a DNDarray along a specified axis. The sketch is created by
    randomly sampling entries of the DNDarray along the specified axis.

    Parameters
    ----------
    a : DNDarray
        The DNDarray for which to create a sketch.
    axis : int or None
        The axis along which to create the sketch. If None, the array is flattened
        first and the sketch is taken over all of its elements.
    sketch_size_relative : optional, float
        The size of the sketch. Fraction of samples to take, hence between 0 and 1.
    sketch_size_absolute : optional, int
        The size of the sketch. Number of samples to take, hence must not exceed the
        size of the axis along which the sketch is taken.

    Returns
    -------
    DNDarray
        The sampled sub-array, resplit to ``split=None``; its extent along ``axis``
        equals the sketch size, all other dimensions are unchanged.

    Raises
    ------
    ValueError
        If not exactly one of `sketch_size_relative` and `sketch_size_absolute` is given.
    """
    # exactly one of the two size arguments has to be provided
    if (sketch_size_relative is None) == (sketch_size_absolute is None):
        raise ValueError(
            "Exactly one of sketch_size_relative and sketch_size_absolute must be specified."
        )

    # axis=None means "over all elements": flatten so that sampling along axis 0
    # covers the whole array (a.shape[None] would raise a TypeError otherwise)
    if axis is None:
        if a.ndim > 1:
            a = a.flatten()
        axis = 0

    if sketch_size_absolute is None:
        sketch_size = int(sketch_size_relative * a.shape[axis])
    else:
        sketch_size = sketch_size_absolute

    # create a sorted random sample of indices along `axis`
    # (drawn with randint, so presumably an index may occur more than once)
    indices = manipulations.sort(
        randint(0, a.shape[axis], sketch_size, device=a.device, dtype=types.int64)
    )[0]

    # move the sampled axis to the front, pick the sampled entries there, and move it back;
    # the indexing must act on the swapped array, not on `a`, otherwise the indices
    # would always be applied to axis 0 regardless of `axis`
    sketch = a.swapaxes(0, axis)
    sketch = sketch[indices, ...].resplit_(None)
    return sketch.swapaxes(0, axis)

# SANITATION
# sanitize input
if not isinstance(x, DNDarray):
raise TypeError(f"expected x to be a DNDarray, but was {type(x)}")
if isinstance(axis, (list, tuple)):
raise NotImplementedError("ht.percentile(), tuple axis not implemented yet")

if sketched:
if (
not isinstance(sketch_size, float)
or sketch_size <= 0
or (MPI.COMM_WORLD.size > 1 and sketch_size == 1)
or sketch_size > 1
):
raise ValueError(
f"If sketched=True, sketch_size must be float strictly between 0 and 1, but is {sketch_size}."
)
else:
x = _create_sketch(x, axis, sketch_size_relative=sketch_size)

if axis is None:
if x.ndim > 1:
x = x.flatten()
Expand Down
20 changes: 20 additions & 0 deletions heat/core/tests/test_statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -1244,6 +1244,26 @@ def test_percentile(self):
with self.assertRaises(ValueError):
ht.percentile(x_ht, q, out=out_wrong_split)

def test_percentile_sketched(self):
    """
    Smoke test for sketched percentiles: checks the result shape for an explicit and
    for the default sketch size, and that invalid sketch sizes raise a ValueError.
    """
    n_proc = ht.MPI_WORLD.size
    n_rows, n_cols = 10 * n_proc, 2 * n_proc
    expected_shape = (n_cols,)
    axis = 0
    q = 50
    use_sketch_of_size = 0.1
    splits = (None, 1, 0)

    # check if it works with an explicitly given sketch size
    for split in splits:
        X = ht.random.rand(n_rows, n_cols, split=split)
        p = ht.percentile(X, q, axis=axis, sketched=True, sketch_size=use_sketch_of_size)
        self.assertTrue(p.shape == expected_shape)

    # default sketch size
    for split in splits:
        X = ht.random.rand(n_rows, n_cols, split=split)
        p = ht.percentile(X, q, axis=axis, sketched=True)
        self.assertTrue(p.shape == expected_shape)

    # check if it raises correct errors (X is the array from the last loop iteration)
    for invalid_sketch_size in (1.1, 10):
        with self.assertRaises(ValueError):
            ht.percentile(X, q, axis=axis, sketched=True, sketch_size=invalid_sketch_size)

def test_skew(self):
x = ht.zeros((2, 3, 4))
with self.assertRaises(ValueError):
Expand Down
34 changes: 31 additions & 3 deletions heat/preprocessing/preprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -448,6 +448,10 @@ class RobustScaler(ht.TransformMixin, ht.BaseEstimator):
the quantile range (defaults to IQR: Interquartile Range); this routine is similar
to ``sklearn.preprocessing.RobustScaler``.

Per default, the "true" median and IQR of the entire data set are computed; however, the argument
`sketched` allows switching to a faster but less accurate version that computes
median and IQR only on a random subset of the data set ("sketch") of size `sketch_size`.

The underlying data set to be scaled must be stored as a 2D-`DNDarray` of shape (n_datapoints, n_features).
Each feature is centered and scaled independently.

Expand All @@ -470,6 +474,14 @@ class RobustScaler(ht.TransformMixin, ht.BaseEstimator):
unit_variance : not yet supported.
raises ``NotImplementedError``

sketched : bool, default=False
If `True`, use a sketch of the data set to compute the median and IQR.
This is faster but less accurate. The size of the sketch is determined by the argument `sketch_size`.

sketch_size : float, default=1./ht.MPI_WORLD.size
Fraction of the data set to be used for the sketch if `sketched=True`. The default value is 1/N, where N is the number of MPI processes.
Ignored if `sketched=False`.

Attributes
----------
center_ : DNDarray of shape (n_features,)
Expand All @@ -490,11 +502,15 @@ def __init__(
quantile_range: Tuple[float, float] = (25.0, 75.0),
copy: bool = True,
unit_variance: bool = False,
sketched: bool = False,
sketch_size: Optional[float] = 1.0 / ht.MPI_WORLD.size,
):
self.with_centering = with_centering
self.with_scaling = with_scaling
self.quantile_range = quantile_range
self.copy = copy
self.sketched = sketched
self.sketch_size = sketch_size
if not with_centering and not with_scaling:
raise ValueError(
"Both centering and scaling are disabled, thus RobustScaler could do nothing. At least one of with_scaling or with_centering must be True."
Expand Down Expand Up @@ -525,10 +541,22 @@ def fit(self, X: ht.DNDarray) -> Self:
"""
_is_2D_float_DNDarray(X)
if self.with_centering:
self.center_ = ht.median(X, axis=0)
self.center_ = ht.median(
X, axis=0, sketched=self.sketched, sketch_size=self.sketch_size
)
if self.with_scaling:
self.iqr_ = ht.percentile(X, self.quantile_range[1], axis=0) - ht.percentile(
X, self.quantile_range[0], axis=0
self.iqr_ = ht.percentile(
X,
self.quantile_range[1],
axis=0,
sketched=self.sketched,
sketch_size=self.sketch_size,
) - ht.percentile(
X,
self.quantile_range[0],
axis=0,
sketched=self.sketched,
sketch_size=self.sketch_size,
)

# if length of iqr is close to zero, do not scale this feature
Expand Down
25 changes: 24 additions & 1 deletion heat/preprocessing/tests/test_preprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,7 @@ def test_MaxAbsScaler(self):

class TestRobustScaler(TestCase):
def test_RobustScaler(self):
for split in [0]:
for split in [0, 1]:
for with_centering in [False, True]:
for with_scaling in [False, True]:
if not with_centering and not with_scaling:
Expand Down Expand Up @@ -348,3 +348,26 @@ def test_RobustScaler(self):
scaler.fit(ht.zeros((10, 10, 10), dtype=ht.float32))
with self.assertRaises(TypeError):
scaler.fit(ht.zeros(10, 10, dtype=ht.int32))

def test_robust_scaler_sketched(self):
    """
    Fit a RobustScaler with sketched median/IQR for both splits and check that
    transform followed by inverse_transform reproduces the input data.
    """
    n_proc = MPI.COMM_WORLD.Get_size()
    scaler_options = {
        "quantile_range": (24.0, 76.0),
        "copy": True,
        "with_centering": True,
        "with_scaling": True,
        "sketched": True,
    }
    for split in (0, 1):
        X = _generate_test_data_set(n_proc * 10, n_proc * 4, split=split, dtype=ht.float32)
        scaler = ht.preprocessing.RobustScaler(**scaler_options)
        scaler.fit(X)
        # round trip: scaling and un-scaling must give back the original data
        Y = scaler.inverse_transform(scaler.transform(X))
        self.assertTrue(ht.allclose(X, Y, atol=atol_inv))
Loading