Skip to content

Commit

Permalink
remove inlined jax.nn.initializers definitions, resolving TODO of lev…
Browse files Browse the repository at this point in the history
…skaya et al

fixes breakage from cl/655766534 aka #21069

PiperOrigin-RevId: 655783604
  • Loading branch information
mattjj authored and jax authors committed Jul 25, 2024
1 parent 76b4c70 commit a94aea5
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions jax/_src/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -2693,9 +2693,9 @@ class ShapeDtypeStruct:
sharding: (optional) a :class:`jax.Sharding` object
"""
__slots__ = ["shape", "dtype", "sharding", "_dll"]
named_shape = {}
named_shape = {} # type: ignore

def __init__(self, shape, dtype, sharding=None, named_shape=None):
def __init__(self, shape, dtype, named_shape=None, sharding=None):
del named_shape # ignored, vestigial
self.shape = tuple(shape)
if dtype is None:
Expand Down

0 comments on commit a94aea5

Please sign in to comment.