Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[typing] add internal DTypeLike specializations #18194

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 4 additions & 5 deletions jax/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -695,12 +695,11 @@ pytype_strict_library(
],
)

pytype_strict_library(
py_library_providing_imports_info(
name = "typing",
srcs = [
"_src/typing.py",
],
deps = [":basearray"] + py_deps("numpy"),
srcs = glob(["_src/typing/**/*.py"]),
lib_rule = pytype_strict_library,
deps = [":basearray"] + py_deps("numpy") + py_deps("ml_dtypes"),
)

pytype_strict_library(
Expand Down
8 changes: 2 additions & 6 deletions jax/_src/nn/initializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,17 +28,13 @@
from jax import random
from jax._src import core
from jax._src import dtypes
from jax._src.typing import Array, ArrayLike
from jax._src.typing import Array, ArrayLike, DTypeLikeFloat, DTypeLikeComplex
from jax._src.util import set_module

export = set_module('jax.nn.initializers')

KeyArray = Array
# TODO: Import or define these to match
# https://github.com/numpy/numpy/blob/main/numpy/typing/_dtype_like.py.
DTypeLikeFloat = Any
DTypeLikeComplex = Any
DTypeLikeInexact = Any # DTypeLikeFloat | DTypeLikeComplex
DTypeLikeInexact = Union[DTypeLikeFloat, DTypeLikeComplex]
RealNumeric = Any # Scalar jnp array or float

@export
Expand Down
9 changes: 2 additions & 7 deletions jax/_src/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,17 +39,12 @@
from jax._src.lax import lax as lax_internal
from jax._src.numpy.lax_numpy import _convert_and_clip_integer
from jax._src.numpy.util import _arraylike, check_arraylike, promote_dtypes_inexact
from jax._src.typing import Array, ArrayLike, DTypeLike
from jax._src.typing import Array, ArrayLike, DTypeLike, DTypeLikeInt, DTypeLikeUInt, DTypeLikeFloat
from jax._src.util import canonicalize_axis


RealArray = ArrayLike
IntegerArray = ArrayLike
# TODO: Import or define these to match
# https://github.com/numpy/numpy/blob/main/numpy/typing/_dtype_like.py.
DTypeLikeInt = DTypeLike
DTypeLikeUInt = DTypeLike
DTypeLikeFloat = DTypeLike
Shape = Sequence[int]

PRNGImpl = prng.PRNGImpl
Expand Down Expand Up @@ -1887,7 +1882,7 @@ def _f(key, dfnum, dfden, shape, dtype) -> Array:

def rademacher(key: KeyArray,
shape: Shape,
dtype: DTypeLikeInt = int) -> Array:
dtype: DTypeLike = int) -> Array:
r"""Sample from a Rademacher distribution.

The values are distributed according to the probability mass function:
Expand Down
69 changes: 23 additions & 46 deletions jax/_src/typing.py → jax/_src/typing/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2022 The JAX Authors.
# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -24,51 +24,6 @@
https://github.com/google/jax/pull/11859/.
"""

from __future__ import annotations

from collections.abc import Sequence
from typing import Any, Protocol, Union
import numpy as np

from jax._src.basearray import (
Array as Array,
ArrayLike as ArrayLike,
)

DType = np.dtype

# TODO(jakevdp, froystig): make ExtendedDType a protocol
ExtendedDType = Any

class SupportsDType(Protocol):
@property
def dtype(self) -> DType: ...

# DTypeLike is meant to annotate inputs to np.dtype that return
# a valid JAX dtype. It's different than numpy.typing.DTypeLike
# because JAX doesn't support objects or structured dtypes.
# Unlike np.typing.DTypeLike, we exclude None, and instead require
# explicit annotations when None is acceptable.
# TODO(jakevdp): consider whether to add ExtendedDtype to the union.
DTypeLike = Union[
str, # like 'float32', 'int32'
type[Any], # like np.float32, np.int32, float, int
np.dtype, # like np.dtype('float32'), np.dtype('int32')
SupportsDType, # like jnp.float32, jnp.int32
]

# Shapes are tuples of dimension sizes, which are normally integers. We allow
# modules to extend the set of dimension sizes to contain other types, e.g.,
# symbolic dimensions in jax2tf.shape_poly.DimVar and masking.Poly.
DimSize = Union[int, Any] # extensible
Shape = Sequence[DimSize]

class DuckTypedArray(Protocol):
@property
def dtype(self) -> DType: ...
@property
def shape(self) -> Shape: ...

# Array is a type annotation for standard JAX arrays and tracers produced by
# core functions in jax.lax and jax.numpy; it is not meant to include
# future non-standard array types like KeyArray and BInt. It is imported above.
Expand All @@ -77,3 +32,25 @@ def shape(self) -> Shape: ...
# JAX array (i.e. not including future non-standard array types like KeyArray and BInt).
# It's different than np.typing.ArrayLike in that it doesn't accept arbitrary sequences,
# nor does it accept string data.

from jax._src.basearray import (
Array as Array,
ArrayLike as ArrayLike,
)

from jax._src.typing.core import (
DimSize as DimSize,
DuckTypedArray as DuckTypedArray,
Shape as Shape,
)

from jax._src.typing.dtypes import (
DType as DType,
DTypeLike as DTypeLike,
DTypeLikeBool as DTypeLikeBool,
DTypeLikeComplex as DTypeLikeComplex,
DTypeLikeFloat as DTypeLikeFloat,
DTypeLikeInt as DTypeLikeInt,
DTypeLikeUInt as DTypeLikeUInt,
ExtendedDType as ExtendedDType,
)
33 changes: 33 additions & 0 deletions jax/_src/typing/core.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from collections.abc import Sequence
from typing import Any, Protocol, Union
from jax._src.typing.dtypes import DType

# Shapes are tuples of dimension sizes, which are normally integers. We allow
# modules to extend the set of dimension sizes to contain other types, e.g.,
# symbolic dimensions in jax2tf.shape_poly.DimVar and masking.Poly.

# TODO(jakevdp): should DimSize extensions be a protocol?
DimSize = Union[int, Any] # extensible
Shape = Sequence[DimSize]

class DuckTypedArray(Protocol):
@property
def dtype(self) -> DType: ...
@property
def shape(self) -> Shape: ...
116 changes: 116 additions & 0 deletions jax/_src/typing/dtypes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from typing import Any, Literal, Protocol, TypeVar, Union
import numpy as np
import ml_dtypes

DType = np.dtype
_ScalarType = TypeVar("_ScalarType", covariant=True, bound=np.generic) # pytype: disable=not-supported-yet

# TODO(jakevdp, froystig): make ExtendedDType a protocol
ExtendedDType = Any
jakevdp marked this conversation as resolved.
Show resolved Hide resolved

class SupportsDType(Protocol[_ScalarType]):
@property
def dtype(self) -> DType[_ScalarType]: ...

# DTypeLike is meant to annotate inputs to np.dtype that return
# a valid JAX dtype. It's different than numpy.typing.DTypeLike
# because JAX doesn't support objects or structured dtypes.
# Unlike np.typing.DTypeLike, we exclude None, and instead require
# explicit annotations when None is acceptable.
# TODO(jakevdp): consider whether to add ExtendedDtype to the union.
DTypeLike = Union[
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is it a bad idea to define DTypeLike as a union of the specialized DTypeLike* aliases below?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

DTypeLike is used in public APIs, so if we change it there's a good chance the perturbation will break the builds of downstream packages. Plus the simplicity here is nice (even if there are invalid strings that it lets through).

str, # like 'float32', 'int32'
type[Any], # like np.float32, np.int32, float, int
DType[Any], # like np.dtype('float32'), np.dtype('int32')
SupportsDType[Any], # like jnp.float32, jnp.int32
]

BoolLiterals = Literal['bool', '?']

Int8Literals = Literal['int8', 'i1']
Int16Literals = Literal['int16', 'i2']
Int32Literals = Literal['int32', 'i4']
Int64Literals = Literal['int', 'int64', 'i8']

UInt8Literals = Literal['uint8', 'u1']
UInt16Literals = Literal['uint16', 'u2']
UInt32Literals = Literal['uint32', 'u4']
UInt64Literals = Literal['uint', 'uint64', 'u8']

BFloat16Literals = Literal['bfloat16']
Float16Literals = Literal['float16', 'f2']
Float32Literals = Literal['float32', 'f4']
Float64Literals = Literal['float', 'float64', 'f8']

Complex64Literals = Literal['complex64', 'c8']
Complex128Literals = Literal['complex', 'complex128', 'c16']

# TODO(jakevdp): the use of things like type[float] and type[np.floating]
# below are not strictly correct: can we do better?
jakevdp marked this conversation as resolved.
Show resolved Hide resolved

DTypeLikeBool = Union[
type[bool],
type[np.bool_],
DType[np.bool_],
SupportsDType[np.bool_],
BoolLiterals,
]

DTypeLikeUInt = Union[
type[np.unsignedinteger],
DType[np.unsignedinteger],
SupportsDType[np.unsignedinteger],
UInt8Literals,
UInt16Literals,
UInt32Literals,
UInt64Literals
]

DTypeLikeInt = Union[
type[int],
type[np.signedinteger],
DType[np.signedinteger],
SupportsDType[np.signedinteger],
Int8Literals,
Int16Literals,
Int32Literals,
Int64Literals,
]

DTypeLikeFloat = Union[
type[float],
type[np.floating],
DType[np.floating],
SupportsDType[np.floating],
DType[ml_dtypes.bfloat16],
SupportsDType[ml_dtypes.bfloat16],
BFloat16Literals,
Float16Literals,
Float32Literals,
Float64Literals
]

DTypeLikeComplex = Union[
type[complex],
type[np.complexfloating],
DType[np.complexfloating],
SupportsDType[np.complexfloating],
Complex64Literals,
Complex128Literals,
]