Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update mean and sum functions #643

Open
wants to merge 7 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions skfda/representation/_functional_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -882,6 +882,7 @@ def mean(
out: None = None,
keepdims: bool = False,
skipna: bool = False,
min_count: int = 0,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is min_count removed?

) -> T:
"""Compute the mean of all the samples.

Expand All @@ -891,6 +892,9 @@ def mean(
out: Used for compatibility with numpy. Must be None.
keepdims: Used for compatibility with numpy. Must be False.
skipna: Whether the NaNs are ignored or not.
min_count: Number of valid (non-NaN) data to have in order
for a variable to not be NaN when `skipna` is
`True`.

Returns:
A FData object with just one sample representing
Expand All @@ -902,10 +906,7 @@ def mean(
"Not implemented for that parameter combination",
)

return (
self.sum(axis=axis, out=out, keepdims=keepdims, skipna=skipna)
/ self.n_samples
)
return self

@abstractmethod
def to_grid(
Expand Down
53 changes: 43 additions & 10 deletions skfda/representation/basis/_fdatabasis.py
Original file line number Diff line number Diff line change
Expand Up @@ -427,20 +427,53 @@ def sum( # noqa: WPS125
"""
super().sum(axis=axis, out=out, keepdims=keepdims, skipna=skipna)

coefs = (
np.nansum(self.coefficients, axis=0) if skipna
else np.sum(self.coefficients, axis=0)
)

if min_count > 0:
valid = ~np.isnan(self.coefficients)
n_valid = np.sum(valid, axis=0)
coefs[n_valid < min_count] = np.nan
valid_functions = ~self.isna()
valid_coefficients = self.coefficients[valid_functions]

coefs = np.sum(valid_coefficients, axis=0)

return self.copy(
coefficients=coefs,
sample_names=(None,),
)

def mean( # noqa: WPS125
self: T,
*,
axis: Optional[int] = None,
dtype: None = None,
out: None = None,
keepdims: bool = False,
skipna: bool = False,
min_count: int = 0,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems to me that min_count is not being used here. Why is that?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is kept for compatibility with the mean functions of FDataIrregular and FDataGrid, but it does not make sense to use it here: you do not have measurements for each observation, only the observations approximated by functions.

) -> T:
"""Compute the mean of all the samples.

Args:
axis: Used for compatibility with numpy. Must be None or 0.
dtype: Used for compatibility with numpy. Must be None.
out: Used for compatibility with numpy. Must be None.
keepdims: Used for compatibility with numpy. Must be False.
skipna: Whether the NaNs are ignored or not.
min_count: Ignored, used for compatibility with FDataGrid
and FDataIrregular.

Returns:
A FDataBasis object with just one sample representing
the mean of all the samples in the original object.
"""
super().mean(axis=axis, dtype=dtype, out=out, keepdims=keepdims,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am no longer sure that we want to do any validation in the abstract class. It is confusing. I would rather move the validation to the subclasses, or, if we do not want to repeat code, to a function in _utils or in a (maybe private for now) function in misc.validation.

skipna=skipna)

return (
self.sum(
axis=axis,
out=out,
keepdims=keepdims,
skipna=skipna,
)
/ np.sum(~self.isna()),
)

def var(
self: T,
Expand Down Expand Up @@ -998,7 +1031,7 @@ def isna(self) -> NDArrayBool:
Returns:
na_values (np.ndarray): Positions of NA.
"""
return np.all( # type: ignore[no-any-return]
return np.any( # type: ignore[no-any-return]
np.isnan(self.coefficients),
axis=1,
)
Expand Down
105 changes: 95 additions & 10 deletions skfda/representation/grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -544,6 +544,60 @@ def _get_points_and_values(self: T) -> Tuple[NDArrayFloat, NDArrayFloat]:

def _get_input_points(self: T) -> GridPoints:
return self.grid_points

def _compute_aggregate(
self: T,
operation: str,
*,
skipna: bool = False,
min_count: int = 0,
) -> T:
"""Compute a defined aggregation operation of the samples.

Args:
operation: Operation to be performed. Can be 'mean', 'sum' or
'var'.
axis: Used for compatibility with numpy. Must be None or 0.
out: Used for compatibility with numpy. Must be None.
keepdims: Used for compatibility with numpy. Must be False.
skipna: Whether the NaNs are ignored or not.
min_count: Number of valid (non-NaN) data to have in order
for a variable to not be NaN when `skipna` is
`True`.

Returns:
An FDataGrid object with just one sample representing
the aggregation of all the samples in the original object.

"""
if operation not in {'sum', 'mean', 'var'}:
raise ValueError("Invalid operation."
"Must be one of 'sum', 'mean', or 'var'.")

if skipna:
agg_func = {
'sum': np.nansum,
'mean': np.nanmean,
'var': np.nanvar
}[operation]
else:
agg_func = {
'sum': np.sum,
'mean': np.mean,
'var': np.var
}[operation]

data = agg_func(self.data_matrix, axis=0, keepdims=True)

if min_count > 0 and skipna:
valid = ~np.isnan(self.data_matrix)
n_valid = np.sum(valid, axis=0)
data[n_valid < min_count] = np.nan
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wouldn't a conditional be more clear?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I do not understand where and how you are suggesting to use a conditional. The code seems clear to me (though, as the author, I might be biased).


return self.copy(
data_matrix=data,
sample_names=(None,),
)

def sum( # noqa: WPS125
self: T,
Expand Down Expand Up @@ -583,19 +637,50 @@ def sum( # noqa: WPS125
"""
super().sum(axis=axis, out=out, keepdims=keepdims, skipna=skipna)

data = (
np.nansum(self.data_matrix, axis=0, keepdims=True) if skipna
else np.sum(self.data_matrix, axis=0, keepdims=True)
return self._compute_aggregate(
operation='sum',
skipna=skipna,
min_count=min_count,
)

if min_count > 0:
valid = ~np.isnan(self.data_matrix)
n_valid = np.sum(valid, axis=0)
data[n_valid < min_count] = np.nan
def mean( # noqa: WPS125
self: T,
*,
axis: Optional[int] = None,
dtype: None = None,
out: None = None,
keepdims: bool = False,
skipna: bool = False,
min_count: int = 0,
) -> T:
"""Compute the mean of all the samples.

return self.copy(
data_matrix=data,
sample_names=(None,),
Args:
axis: Used for compatibility with numpy. Must be None or 0.
dtype: Used for compatibility with numpy. Must be None.
out: Used for compatibility with numpy. Must be None.
keepdims: Used for compatibility with numpy. Must be False.
skipna: Whether the NaNs are ignored or not.
min_count: Number of valid (non-NaN) data to have in order
for a variable to not be NaN when `skipna` is
`True`.

Returns:
A FDataGrid object with just one sample representing
the mean of all the samples in the original object.
"""
super().mean(
axis=axis,
dtype=dtype,
out=out,
keepdims=keepdims,
skipna=skipna,
)

return self._compute_aggregate(
operation='mean',
skipna=skipna,
min_count=min_count,
)

def var(self: T, correction: int = 0) -> T:
Expand Down
59 changes: 59 additions & 0 deletions skfda/representation/irregular.py
Original file line number Diff line number Diff line change
Expand Up @@ -716,6 +716,65 @@ def sum( # noqa: WPS125
values=sum_values,
sample_names=(None,),
)

def mean( # noqa: WPS125
self: T,
*,
axis: Optional[int] = None,
dtype: None = None,
out: None = None,
keepdims: bool = False,
skipna: bool = False,
min_count: int = 0,
) -> T:
"""Compute the mean of all the samples.

Args:
axis: Used for compatibility with numpy. Must be None or 0.
dtype: Used for compatibility with numpy. Must be None.
out: Used for compatibility with numpy. Must be None.
keepdims: Used for compatibility with numpy. Must be False.
skipna: Whether the NaNs are ignored or not.
min_count: Number of valid (non-NaN) data to have in order
for a variable to not be NaN when `skipna` is
`True`.

Returns:
An FDataIrregular object with just one sample representing
the mean of all the samples in the original object.
"""
super().mean(
axis=axis,
dtype=dtype,
out=out,
keepdims=keepdims,
skipna=skipna,
)

common_points, common_values = self._get_common_points_and_values()

if len(common_points) == 0:
raise ValueError("No common points in FDataIrregular object")

sum_function = np.nansum if skipna else np.sum
sum_values = sum_function(common_values, axis=0)

if skipna:
count_values = np.sum(~np.isnan(common_values), axis=0)
else:
count_values = np.full(sum_values.shape, self.n_samples)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Isn't this just self.n_samples?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It needs to be in array form so that it can operate with sum_values and fit seamlessly with the flow of the case where skipna is specified.


if min_count > 0 and skipna:
count_values[count_values < min_count] = np.nan

mean_values = sum_values / count_values

return FDataIrregular(
start_indices=np.array([0]),
points=common_points,
values=mean_values,
sample_names=(None,),
)

def var(self: T, correction: int = 0) -> T:
"""Compute the variance of all the samples.
Expand Down