From 098652734f18ef7a75e9f9b47b52f67e4792046e Mon Sep 17 00:00:00 2001 From: Peter Vieting Date: Mon, 5 Dec 2022 14:59:13 +0100 Subject: [PATCH] fix for tuples with non-hashable entries --- pytorch_to_returnn/import_wrapper/wrap.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/pytorch_to_returnn/import_wrapper/wrap.py b/pytorch_to_returnn/import_wrapper/wrap.py index d9a97e7f..2f03bf3e 100644 --- a/pytorch_to_returnn/import_wrapper/wrap.py +++ b/pytorch_to_returnn/import_wrapper/wrap.py @@ -10,11 +10,18 @@ def wrap(obj, *, name: str, ctx: WrapCtx): + def _fully_hashable(obj_): + if not isinstance(obj_, Hashable): + return False + if isinstance(obj_, tuple): + return all(_fully_hashable(elem) for elem in obj_) + return True + if isinstance(obj, (WrappedObject, WrappedModule)): return obj if isinstance(obj, ctx.keep_as_is_types): return obj - if isinstance(obj, Hashable) and obj in ctx.explicit_wrapped_objects: + if _fully_hashable(obj) and obj in ctx.explicit_wrapped_objects: func = ctx.explicit_wrapped_objects[obj] obj = func(obj, name=name, ctx=ctx) obj = _nested_transform(obj, lambda _x: wrap(_x, name="%s..." % name, ctx=ctx))