diff --git a/tests/test_torch_numerics.py b/tests/test_torch_numerics.py index dfc3fc3..830175a 100644 --- a/tests/test_torch_numerics.py +++ b/tests/test_torch_numerics.py @@ -56,8 +56,13 @@ def perform_test(atom, in_np, in_tch): This is an internal function that performs a test and compares the result between the np version and the torch version. """ + print(f"in_np = {in_np}") + print(f"in_tch = {in_tch}") res_np = atom.numeric(in_np) res_tch = EXPR2TORCH.get(type(atom)).torch_numeric(atom, in_tch) + print(f"res_np = {res_np}") + print(f"res_tch = {res_tch}") + print("="*30) if type(res_tch) is torch.Tensor: res_tch = res_tch.detach().numpy() assert np.allclose(res_np, res_tch, rtol=RTOL, atol=ATOL, equal_nan=True)