Skip to content

Commit

Permalink
Move dataclass registration to __init__ so that it's invoked after de…
Browse files Browse the repository at this point in the history
…serialization.

PiperOrigin-RevId: 410790272
  • Loading branch information
hamzamerzic authored and ChexDev committed Nov 18, 2021
1 parent fde0e72 commit cc0cdcc
Showing 1 changed file with 15 additions and 2 deletions.
17 changes: 15 additions & 2 deletions chex/_src/dataclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def mappable_dataclass(cls):
all_fields = set(f.name for f in cls.__dataclass_fields__.values())
init_fields = [f.name for f in cls.__dataclass_fields__.values() if f.init]

@functools.wraps(cls.__init__)
@functools.wraps(orig_init)
def new_init(self, *orig_args, **orig_kwargs):
if (orig_args and orig_kwargs) or len(orig_args) > 1:
raise ValueError(
Expand Down Expand Up @@ -126,6 +126,7 @@ def __init__(
self.unsafe_hash = unsafe_hash
self.frozen = frozen
self.mappable_dataclass = mappable_dataclass
self.registered = False

def __call__(self, cls):
"""Forwards class to dataclasses's wrapper and registers it with JAX."""
Expand Down Expand Up @@ -176,13 +177,25 @@ def _getstate(self):
def _setstate(self, state):
self.__dict__.update(state)

class_self = self
orig_init = dcls.__init__

# Patch object's __init__ such that the class is registered on creation.
# This ensures correct registration on deserialization.
@functools.wraps(orig_init)
def _init(self, *args, **kwargs):
if not class_self.registered:
_register_dataclass_type(dcls)
class_self.registered = True
return orig_init(self, *args, **kwargs)

setattr(dcls, "from_tuple", _from_tuple)
setattr(dcls, "to_tuple", _to_tuple)
setattr(dcls, "replace", _replace)
setattr(dcls, "__getstate__", _getstate)
setattr(dcls, "__setstate__", _setstate)
setattr(dcls, "__init__", _init)

_register_dataclass_type(dcls)
return dcls


Expand Down

0 comments on commit cc0cdcc

Please sign in to comment.