Skip to content

Commit

Permalink
Simplify transform conditions in _get_restore_parameters().
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 731453784
  • Loading branch information
BlaziusMaximus authored and Orbax Authors committed Feb 26, 2025
1 parent c69a872 commit 0658e5f
Showing 1 changed file with 11 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -299,7 +299,7 @@ def _get_param_info(
flat_param_names[input_key], meta
)
flat_input_restore_args[input_key] = maybe_input_args
elif input_key in flat_item and input_key in flat_structure:
elif input_key in flat_item:
# Key is present in both input and output.
if _has_use_fallback_transform(input_key, flat_transforms):
# Indicates that a `use_fallback` transformation was specified.
Expand All @@ -317,19 +317,17 @@ def _get_param_info(
flat_param_names[input_key], meta
)
flat_input_restore_args[input_key] = flat_restore_args[input_key]
elif transforms_default_to_original:
# Key/value is carried over from the original unchanged.
flat_param_infos[input_key] = _get_param_info(
flat_param_names[input_key], meta
)
flat_input_restore_args[input_key] = flat_restore_args[input_key]
else:
# Transform not specified.
if transforms_default_to_original:
# Key/value is carried over from the original unchanged.
flat_param_infos[input_key] = _get_param_info(
flat_param_names[input_key], meta
)
flat_input_restore_args[input_key] = flat_restore_args[input_key]
else:
# Take the value from the user-provided `item`, ignoring any value
# in the checkpoint.
flat_param_infos[input_key] = ParamInfo(skip_deserialize=True)
flat_input_restore_args[input_key] = RestoreArgs()
# Take the value from the user-provided `item`, ignoring any value
# in the checkpoint.
flat_param_infos[input_key] = ParamInfo(skip_deserialize=True)
flat_input_restore_args[input_key] = RestoreArgs()
else:
# No match, restoration not required since it will be dropped from the
# output.
Expand Down

0 comments on commit 0658e5f

Please sign in to comment.