diff --git a/nff/md/tully/io.py b/nff/md/tully/io.py index 0e3ff1b2..e481c849 100644 --- a/nff/md/tully/io.py +++ b/nff/md/tully/io.py @@ -450,7 +450,7 @@ def concat_and_conv(results_list, for key in keys: val = torch.cat([i[key] for i in results_list]) - if 'energy_grad' in key or 'force_nacv' in key: + if ('energy' in key and '_grad' in key) or 'force_nacv' in key: val *= conv['energy'] * conv['_grad'] val = val.reshape(*grad_shape) elif 'energy' in key or key in diabat_keys: