Skip to content

Commit

Permalink
reinstate try-catch for loading missing modules
Browse files Browse the repository at this point in the history
Signed-off-by: Rohit Jena <[email protected]>
  • Loading branch information
rohitrango committed Jul 30, 2024
1 parent 95a8a98 commit 1bd4a73
Showing 1 changed file with 7 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1245,10 +1245,13 @@ def _load_pretrained_model(self, state_dict, ignore_mismatched_sizes=False, from
):
# GroupNormOpt fuses activation function to one layer, thus the indexing of weights are shifted for following
for key_ in missing_keys:
s = key_.split('.')
idx = int(s[-2])
new_key_ = ".".join(s[:-2] + [str(int(idx + 1))] + [s[-1]])
state_dict[key_] = state_dict[new_key_]
try:
s = key_.split('.')
idx = int(s[-2])
new_key_ = ".".join(s[:-2] + [str(int(idx + 1))] + [s[-1]])
state_dict[key_] = state_dict[new_key_]
except:
continue

loaded_keys = list(state_dict.keys())
missing_keys = list(set(expected_keys) - set(loaded_keys))
Expand Down

0 comments on commit 1bd4a73

Please sign in to comment.