Skip to content

Commit

Permalink
Typing for open_dataset/array/mfdataset and to_netcdf/zarr (pydata#6612)
Browse files Browse the repository at this point in the history
* type filename and chunks

* type open_dataset, open_dataarray, open_mfdataset

* type to_netcdf

* add return doc to Dataset.to_netcdf

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix import error

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* replace tuple[x] by Tuple[x] for py3.8

* fix some merge errors

* add overloads to to_zarr

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix absolute import

* CamelCase type vars

* move some literal type to core.types

* add JoinOptions to core.types

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* add some blank lines under bullet lists in docs

* add comments to overloads

* some more typing

* fix absolute import

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Delete mypy.py

whoops, accidental upload

* fix typo

* fix absolute import

* fix some absolute imports

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* replace Dict by dict

* fix DataArray import

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix _dataset_concat arg name

* fix DataArray not imported

* remove xr import in Dataset

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* some more typing

* replace some Sequence by Iterable

* fix wrong default in docstring

* fix docstring indentation

* fix overloads and type some tests

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix open_mfdataset typing

* minor update of docstring

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* remove unnecessary import

* fix overloads of to_netcdf

* minor docstring update

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
headtr1ck and pre-commit-ci[bot] authored May 17, 2022
1 parent e712270 commit 6b1d97a
Show file tree
Hide file tree
Showing 16 changed files with 885 additions and 364 deletions.
364 changes: 259 additions & 105 deletions xarray/backends/api.py

Large diffs are not rendered by default.

14 changes: 8 additions & 6 deletions xarray/backends/common.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from __future__ import annotations

import logging
import os
import time
import traceback
from typing import Any, Dict, Tuple, Type, Union
from typing import Any

import numpy as np

Expand Down Expand Up @@ -369,13 +371,13 @@ class BackendEntrypoint:
method is not mandatory.
"""

open_dataset_parameters: Union[Tuple, None] = None
open_dataset_parameters: tuple | None = None
"""list of ``open_dataset`` method parameters"""

def open_dataset(
self,
filename_or_obj: str,
drop_variables: Tuple[str] = None,
filename_or_obj: str | os.PathLike,
drop_variables: tuple[str] | None = None,
**kwargs: Any,
):
"""
Expand All @@ -384,12 +386,12 @@ def open_dataset(

raise NotImplementedError

def guess_can_open(self, filename_or_obj):
def guess_can_open(self, filename_or_obj: str | os.PathLike):
"""
Backend guess_can_open method used by Xarray in :py:func:`~xarray.open_dataset`.
"""

return False


BACKEND_ENTRYPOINTS: Dict[str, Type[BackendEntrypoint]] = {}
BACKEND_ENTRYPOINTS: dict[str, type[BackendEntrypoint]] = {}
8 changes: 5 additions & 3 deletions xarray/core/alignment.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
if TYPE_CHECKING:
from .dataarray import DataArray
from .dataset import Dataset
from .types import JoinOptions

DataAlignable = TypeVar("DataAlignable", bound=DataWithCoords)

Expand Down Expand Up @@ -557,7 +558,7 @@ def align(self) -> None:

def align(
*objects: DataAlignable,
join="inner",
join: JoinOptions = "inner",
copy=True,
indexes=None,
exclude=frozenset(),
Expand Down Expand Up @@ -590,6 +591,7 @@ def align(
- "override": if indexes are of same size, rewrite indexes to be
those of the first object with that dimension. Indexes for the same
dimension must have the same size in all objects.
copy : bool, optional
If ``copy=True``, data in the return values is always copied. If
``copy=False`` and reindexing is unnecessary, or can be performed with
Expand Down Expand Up @@ -764,7 +766,7 @@ def align(

def deep_align(
objects,
join="inner",
join: JoinOptions = "inner",
copy=True,
indexes=None,
exclude=frozenset(),
Expand Down Expand Up @@ -834,7 +836,7 @@ def is_alignable(obj):
if key is no_key:
out[position] = aligned_obj
else:
out[position][key] = aligned_obj
out[position][key] = aligned_obj # type: ignore[index] # maybe someone can fix this?

# something went wrong: we should have replaced all sentinel values
for arg in out:
Expand Down
71 changes: 39 additions & 32 deletions xarray/core/combine.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from __future__ import annotations

import itertools
import warnings
from collections import Counter
from typing import Iterable, Sequence, Union
from typing import TYPE_CHECKING, Iterable, Literal, Sequence, Union

import pandas as pd

Expand All @@ -12,6 +14,9 @@
from .merge import merge
from .utils import iterate_nested

if TYPE_CHECKING:
from .types import CombineAttrsOptions, CompatOptions, JoinOptions


def _infer_concat_order_from_positions(datasets):
return dict(_infer_tile_ids_from_nested_list(datasets, ()))
Expand Down Expand Up @@ -188,10 +193,10 @@ def _combine_nd(
concat_dims,
data_vars="all",
coords="different",
compat="no_conflicts",
compat: CompatOptions = "no_conflicts",
fill_value=dtypes.NA,
join="outer",
combine_attrs="drop",
join: JoinOptions = "outer",
combine_attrs: CombineAttrsOptions = "drop",
):
"""
Combines an N-dimensional structure of datasets into one by applying a
Expand Down Expand Up @@ -250,10 +255,10 @@ def _combine_all_along_first_dim(
dim,
data_vars,
coords,
compat,
compat: CompatOptions,
fill_value=dtypes.NA,
join="outer",
combine_attrs="drop",
join: JoinOptions = "outer",
combine_attrs: CombineAttrsOptions = "drop",
):

# Group into lines of datasets which must be combined along dim
Expand All @@ -276,12 +281,12 @@ def _combine_all_along_first_dim(
def _combine_1d(
datasets,
concat_dim,
compat="no_conflicts",
compat: CompatOptions = "no_conflicts",
data_vars="all",
coords="different",
fill_value=dtypes.NA,
join="outer",
combine_attrs="drop",
join: JoinOptions = "outer",
combine_attrs: CombineAttrsOptions = "drop",
):
"""
Applies either concat or merge to 1D list of datasets depending on value
Expand Down Expand Up @@ -336,8 +341,8 @@ def _nested_combine(
coords,
ids,
fill_value=dtypes.NA,
join="outer",
combine_attrs="drop",
join: JoinOptions = "outer",
combine_attrs: CombineAttrsOptions = "drop",
):

if len(datasets) == 0:
Expand Down Expand Up @@ -377,15 +382,13 @@ def _nested_combine(

def combine_nested(
datasets: DATASET_HYPERCUBE,
concat_dim: Union[
str, DataArray, None, Sequence[Union[str, "DataArray", pd.Index, None]]
],
concat_dim: (str | DataArray | None | Sequence[str | DataArray | pd.Index | None]),
compat: str = "no_conflicts",
data_vars: str = "all",
coords: str = "different",
fill_value: object = dtypes.NA,
join: str = "outer",
combine_attrs: str = "drop",
join: JoinOptions = "outer",
combine_attrs: CombineAttrsOptions = "drop",
) -> Dataset:
"""
Explicitly combine an N-dimensional grid of datasets into one by using a
Expand Down Expand Up @@ -603,9 +606,9 @@ def _combine_single_variable_hypercube(
fill_value=dtypes.NA,
data_vars="all",
coords="different",
compat="no_conflicts",
join="outer",
combine_attrs="no_conflicts",
compat: CompatOptions = "no_conflicts",
join: JoinOptions = "outer",
combine_attrs: CombineAttrsOptions = "no_conflicts",
):
"""
Attempt to combine a list of Datasets into a hypercube using their
Expand Down Expand Up @@ -659,15 +662,15 @@ def _combine_single_variable_hypercube(

# TODO remove empty list default param after version 0.21, see PR4696
def combine_by_coords(
data_objects: Sequence[Union[Dataset, DataArray]] = [],
compat: str = "no_conflicts",
data_vars: str = "all",
data_objects: Iterable[Dataset | DataArray] = [],
compat: CompatOptions = "no_conflicts",
data_vars: Literal["all", "minimal", "different"] | list[str] = "all",
coords: str = "different",
fill_value: object = dtypes.NA,
join: str = "outer",
combine_attrs: str = "no_conflicts",
datasets: Sequence[Dataset] = None,
) -> Union[Dataset, DataArray]:
join: JoinOptions = "outer",
combine_attrs: CombineAttrsOptions = "no_conflicts",
datasets: Iterable[Dataset] = None,
) -> Dataset | DataArray:
"""
Attempt to auto-magically combine the given datasets (or data arrays)
Expand Down Expand Up @@ -695,7 +698,7 @@ def combine_by_coords(
Parameters
----------
data_objects : sequence of xarray.Dataset or sequence of xarray.DataArray
data_objects : Iterable of Datasets or DataArrays
Data objects to combine.
compat : {"identical", "equals", "broadcast_equals", "no_conflicts", "override"}, optional
Expand All @@ -711,18 +714,19 @@ def combine_by_coords(
must be equal. The returned dataset then contains the combination
of all non-null values.
- "override": skip comparing and pick variable from first dataset
data_vars : {"minimal", "different", "all"} or list of str, optional
These data variables will be concatenated together:
* "minimal": Only data variables in which the dimension already
- "minimal": Only data variables in which the dimension already
appears are included.
* "different": Data variables which are not equal (ignoring
- "different": Data variables which are not equal (ignoring
attributes) across all datasets are also concatenated (as well as
all for which dimension already appears). Beware: this option may
load the data payload of data variables into memory if they are not
already loaded.
* "all": All data variables will be concatenated.
* list of str: The listed data variables will be concatenated, in
- "all": All data variables will be concatenated.
- list of str: The listed data variables will be concatenated, in
addition to the "minimal" data variables.
If objects are DataArrays, `data_vars` must be "all".
Expand All @@ -745,6 +749,7 @@ def combine_by_coords(
- "override": if indexes are of same size, rewrite indexes to be
those of the first object with that dimension. Indexes for the same
dimension must have the same size in all objects.
combine_attrs : {"drop", "identical", "no_conflicts", "drop_conflicts", \
"override"} or callable, default: "no_conflicts"
A callable or a string indicating how to combine attrs of the objects being
Expand All @@ -762,6 +767,8 @@ def combine_by_coords(
If a callable, it must expect a sequence of ``attrs`` dicts and a context object
as its only parameters.
datasets : Iterable of Datasets
Returns
-------
combined : xarray.Dataset or xarray.DataArray
Expand Down
35 changes: 26 additions & 9 deletions xarray/core/computation.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
from .coordinates import Coordinates
from .dataarray import DataArray
from .dataset import Dataset
from .types import T_Xarray
from .types import CombineAttrsOptions, JoinOptions, T_Xarray

_NO_FILL_VALUE = utils.ReprObject("<no-fill-value>")
_DEFAULT_NAME = utils.ReprObject("<default-name>")
Expand Down Expand Up @@ -184,7 +184,7 @@ def _enumerate(dim):
return str(alt_signature)


def result_name(objects: list) -> Any:
def result_name(objects: Iterable[Any]) -> Any:
# use the same naming heuristics as pandas:
# https://github.com/blaze/blaze/issues/458#issuecomment-51936356
names = {getattr(obj, "name", _DEFAULT_NAME) for obj in objects}
Expand All @@ -196,7 +196,7 @@ def result_name(objects: list) -> Any:
return name


def _get_coords_list(args) -> list[Coordinates]:
def _get_coords_list(args: Iterable[Any]) -> list[Coordinates]:
coords_list = []
for arg in args:
try:
Expand All @@ -209,23 +209,39 @@ def _get_coords_list(args) -> list[Coordinates]:


def build_output_coords_and_indexes(
args: list,
args: Iterable[Any],
signature: _UFuncSignature,
exclude_dims: AbstractSet = frozenset(),
combine_attrs: str = "override",
combine_attrs: CombineAttrsOptions = "override",
) -> tuple[list[dict[Any, Variable]], list[dict[Any, Index]]]:
"""Build output coordinates and indexes for an operation.
Parameters
----------
args : list
args : Iterable
List of raw operation arguments. Any valid types for xarray operations
are OK, e.g., scalars, Variable, DataArray, Dataset.
signature : _UFuncSignature
Core dimensions signature for the operation.
exclude_dims : set, optional
Dimensions excluded from the operation. Coordinates along these
dimensions are dropped.
combine_attrs : {"drop", "identical", "no_conflicts", "drop_conflicts", \
"override"} or callable, default: "override"
A callable or a string indicating how to combine attrs of the objects being
merged:
- "drop": empty attrs on returned Dataset.
- "identical": all attrs must be the same on every object.
- "no_conflicts": attrs from all objects are combined, any that have
the same name must also have the same value.
- "drop_conflicts": attrs from all objects are combined, any that have
the same name but different values are dropped.
- "override": skip comparing and copy attrs from the first dataset to
the result.
If a callable, it must expect a sequence of ``attrs`` dicts and a context object
as its only parameters.
Returns
-------
Expand Down Expand Up @@ -267,10 +283,10 @@ def apply_dataarray_vfunc(
func,
*args,
signature,
join="inner",
join: JoinOptions = "inner",
exclude_dims=frozenset(),
keep_attrs="override",
):
) -> tuple[DataArray, ...] | DataArray:
"""Apply a variable level function over DataArray, Variable and/or ndarray
objects.
"""
Expand All @@ -295,6 +311,7 @@ def apply_dataarray_vfunc(
data_vars = [getattr(a, "variable", a) for a in args]
result_var = func(*data_vars)

out: tuple[DataArray, ...] | DataArray
if signature.num_outputs > 1:
out = tuple(
DataArray(
Expand Down Expand Up @@ -829,7 +846,7 @@ def apply_ufunc(
output_core_dims: Sequence[Sequence] | None = ((),),
exclude_dims: AbstractSet = frozenset(),
vectorize: bool = False,
join: str = "exact",
join: JoinOptions = "exact",
dataset_join: str = "exact",
dataset_fill_value: object = _NO_FILL_VALUE,
keep_attrs: bool | str | None = None,
Expand Down
Loading

0 comments on commit 6b1d97a

Please sign in to comment.