Skip to content

Commit

Permalink
[typing] add internal DTypeLike specializations
Browse files Browse the repository at this point in the history
  • Loading branch information
jakevdp committed Oct 19, 2023
1 parent 741b71f commit 6d87bfd
Show file tree
Hide file tree
Showing 5 changed files with 172 additions and 58 deletions.
9 changes: 4 additions & 5 deletions jax/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -277,9 +277,10 @@ pytype_strict_library(
] + py_deps("numpy"),
)

pytype_strict_library(
py_library_providing_imports_info(
name = "basearray",
srcs = ["_src/basearray.py"],
lib_rule = pytype_strict_library,
pytype_srcs = ["_src/basearray.pyi"],
deps = [
":sharding",
Expand Down Expand Up @@ -697,10 +698,8 @@ pytype_strict_library(

pytype_strict_library(
name = "typing",
srcs = [
"_src/typing.py",
],
deps = [":basearray"] + py_deps("numpy"),
srcs = glob(["_src/typing/**/*.py"]),
deps = [":basearray"] + py_deps("numpy") + py_deps("ml_dtypes"),
)

pytype_strict_library(
Expand Down
9 changes: 2 additions & 7 deletions jax/_src/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,17 +39,12 @@
from jax._src.lax import lax as lax_internal
from jax._src.numpy.lax_numpy import _convert_and_clip_integer
from jax._src.numpy.util import _arraylike, check_arraylike, promote_dtypes_inexact
from jax._src.typing import Array, ArrayLike, DTypeLike
from jax._src.typing import Array, ArrayLike, DTypeLike, DTypeLikeInt, DTypeLikeUInt, DTypeLikeFloat
from jax._src.util import canonicalize_axis


RealArray = ArrayLike
IntegerArray = ArrayLike
# TODO: Import or define these to match
# https://github.com/numpy/numpy/blob/main/numpy/typing/_dtype_like.py.
DTypeLikeInt = DTypeLike
DTypeLikeUInt = DTypeLike
DTypeLikeFloat = DTypeLike
Shape = Sequence[int]

PRNGImpl = prng.PRNGImpl
Expand Down Expand Up @@ -1887,7 +1882,7 @@ def _f(key, dfnum, dfden, shape, dtype) -> Array:

def rademacher(key: KeyArray,
shape: Shape,
dtype: DTypeLikeInt = int) -> Array:
dtype: DTypeLike = int) -> Array:
r"""Sample from a Rademacher distribution.
The values are distributed according to the probability mass function:
Expand Down
69 changes: 23 additions & 46 deletions jax/_src/typing.py → jax/_src/typing/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2022 The JAX Authors.
# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -24,51 +24,6 @@
https://github.com/google/jax/pull/11859/.
"""

from __future__ import annotations

from collections.abc import Sequence
from typing import Any, Protocol, Union
import numpy as np

from jax._src.basearray import (
Array as Array,
ArrayLike as ArrayLike,
)

DType = np.dtype

# TODO(jakevdp, froystig): make ExtendedDType a protocol
ExtendedDType = Any

class SupportsDType(Protocol):
@property
def dtype(self) -> DType: ...

# DTypeLike is meant to annotate inputs to np.dtype that return
# a valid JAX dtype. It's different than numpy.typing.DTypeLike
# because JAX doesn't support objects or structured dtypes.
# Unlike np.typing.DTypeLike, we exclude None, and instead require
# explicit annotations when None is acceptable.
# TODO(jakevdp): consider whether to add ExtendedDtype to the union.
DTypeLike = Union[
str, # like 'float32', 'int32'
type[Any], # like np.float32, np.int32, float, int
np.dtype, # like np.dtype('float32'), np.dtype('int32')
SupportsDType, # like jnp.float32, jnp.int32
]

# Shapes are tuples of dimension sizes, which are normally integers. We allow
# modules to extend the set of dimension sizes to contain other types, e.g.,
# symbolic dimensions in jax2tf.shape_poly.DimVar and masking.Poly.
DimSize = Union[int, Any] # extensible
Shape = Sequence[DimSize]

class DuckTypedArray(Protocol):
@property
def dtype(self) -> DType: ...
@property
def shape(self) -> Shape: ...

# Array is a type annotation for standard JAX arrays and tracers produced by
# core functions in jax.lax and jax.numpy; it is not meant to include
# future non-standard array types like KeyArray and BInt. It is imported above.
Expand All @@ -77,3 +32,25 @@ def shape(self) -> Shape: ...
# JAX array (i.e. not including future non-standard array types like KeyArray and BInt).
# It's different than np.typing.ArrayLike in that it doesn't accept arbitrary sequences,
# nor does it accept string data.

from jax._src.basearray import (
Array as Array,
ArrayLike as ArrayLike,
)

from jax._src.typing.core import (
DimSize as DimSize,
DuckTypedArray as DuckTypedArray,
Shape as Shape,
)

from jax._src.typing.dtypes import (
DType as DType,
DTypeLike as DTypeLike,
DTypeLikeBool as DTypeLikeBool,
DTypeLikeComplex as DTypeLikeComplex,
DTypeLikeFloat as DTypeLikeFloat,
DTypeLikeInt as DTypeLikeInt,
DTypeLikeUInt as DTypeLikeUInt,
ExtendedDType as ExtendedDType,
)
33 changes: 33 additions & 0 deletions jax/_src/typing/core.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from collections.abc import Sequence
from typing import Any, Protocol, Union
from jax._src.typing.dtypes import DType

# Shapes are tuples of dimension sizes, which are normally integers. We allow
# modules to extend the set of dimension sizes to contain other types, e.g.,
# symbolic dimensions in jax2tf.shape_poly.DimVar and masking.Poly.

# TODO(jakevdp): should DimSize extensions be a protocol?
DimSize = Union[int, Any] # extensible
Shape = Sequence[DimSize]

class DuckTypedArray(Protocol):
@property
def dtype(self) -> DType: ...
@property
def shape(self) -> Shape: ...
110 changes: 110 additions & 0 deletions jax/_src/typing/dtypes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from typing import Any, Literal, Protocol, Union
import numpy as np
import ml_dtypes

DType = np.dtype

# TODO(jakevdp, froystig): make ExtendedDType a protocol
ExtendedDType = Any

class SupportsDType(Protocol):
@property
def dtype(self) -> DType: ...

# DTypeLike is meant to annotate inputs to np.dtype that return
# a valid JAX dtype. It's different than numpy.typing.DTypeLike
# because JAX doesn't support objects or structured dtypes.
# Unlike np.typing.DTypeLike, we exclude None, and instead require
# explicit annotations when None is acceptable.
# TODO(jakevdp): consider whether to add ExtendedDtype to the union.
DTypeLike = Union[
str, # like 'float32', 'int32'
type[Any], # like np.float32, np.int32, float, int
np.dtype, # like np.dtype('float32'), np.dtype('int32')
SupportsDType, # like jnp.float32, jnp.int32
]

BoolLiterals = Literal['bool', '?']

Int8Literals = Literal['int8', 'i1']
Int16Literals = Literal['int16', 'i2']
Int32Literals = Literal['int32', 'i4']
Int64Literals = Literal['int', 'int64', 'i8']

UInt8Literals = Literal['uint8', 'u1']
UInt16Literals = Literal['uint16', 'u2']
UInt32Literals = Literal['uint32', 'u4']
UInt64Literals = Literal['uint', 'uint64', 'u8']

BFloat16Literals = Literal['bfloat16']
Float16Literals = Literal['float16', 'f2']
Float32Literals = Literal['float32', 'f4']
Float64Literals = Literal['float', 'float64', 'f8']

Complex64Literals = Literal['complex64', 'c8']
Complex128Literals = Literal['complex', 'complex128', 'c16']

# TODO(jakevdp): the use of things like type[float] and type[np.floating]
# below are not strictly correct: can we do better?
# TODO(jakevdp): we should also add type-specific SupportsDType below.

DTypeLikeBool = Union[
type[bool],
type[np.bool_],
DType[np.bool_],
BoolLiterals,
]

DTypeLikeUInt = Union[
type[np.unsignedinteger],
DType[np.unsignedinteger],
UInt8Literals,
UInt16Literals,
UInt32Literals,
UInt64Literals
]

DTypeLikeInt = Union[
type[int],
type[np.signedinteger],
DType[np.signedinteger],
Int8Literals,
Int16Literals,
Int32Literals,
Int64Literals,
]

DTypeLikeFloat = Union[
type[float],
type[np.floating],
DType[np.floating],
DType[ml_dtypes.bfloat16],
BFloat16Literals,
Float16Literals,
Float32Literals,
Float64Literals
]

DTypeLikeComplex = Union[
type[complex],
type[np.complexfloating],
DType[np.complexfloating],
Complex64Literals,
Complex128Literals,
]

0 comments on commit 6d87bfd

Please sign in to comment.