diff --git a/eminus/extras/torch.py b/eminus/extras/torch.py index c62f2aba..55f34855 100644 --- a/eminus/extras/torch.py +++ b/eminus/extras/torch.py @@ -39,7 +39,7 @@ def I(atoms, W): if W.ndim < 3: if len(W) == len(atoms.G2): - Wfft = np.copy(W) + Wfft = W else: if W.ndim == 1: Wfft = np.zeros(n, dtype=W.dtype) @@ -48,7 +48,7 @@ def I(atoms, W): Wfft[atoms.active] = W else: if W.shape[1] == len(atoms.G2): - Wfft = np.copy(W) + Wfft = W else: Wfft = np.zeros((atoms.Nspin, n, atoms.Nstate), dtype=W.dtype) Wfft[:, atoms.active[0]] = W @@ -132,15 +132,15 @@ def Idag(atoms, W, full=False): if W.ndim == 1: Wfft = Wfft.view(s) - F = torch.fft.fftn(Wfft, s=s, norm='backward').view(n) + F = torch.fft.fftn(Wfft, s=s, norm='forward').view(n) elif W.ndim == 2: Wfft = Wfft.view(s + (atoms.Nstate,)) - F = torch.fft.fftn(Wfft, s=s, norm='backward', dim=(0, 1, 2)).view(n, atoms.Nstate) + F = torch.fft.fftn(Wfft, s=s, norm='forward', dim=(0, 1, 2)).view(n, atoms.Nstate) else: Wfft = Wfft.view((atoms.Nspin,) + s + (atoms.Nstate,)) - F = torch.fft.fftn(Wfft, s=s, norm='backward', dim=(1, 2, 3)).view(atoms.Nspin, n, - atoms.Nstate) - F = F.detach().cpu().numpy() + F = torch.fft.fftn(Wfft, s=s, norm='forward', dim=(1, 2, 3)).view(atoms.Nspin, n, + atoms.Nstate) + F = F.detach().cpu().numpy() * n if not full: if F.ndim < 3: return F[atoms.active] @@ -164,7 +164,7 @@ def Jdag(atoms, W): if W.ndim < 3: if len(W) == len(atoms.G2): - Wfft = np.copy(W) + Wfft = W else: if W.ndim == 1: Wfft = np.zeros(n, dtype=W.dtype) @@ -173,7 +173,7 @@ def Jdag(atoms, W): Wfft[atoms.active] = W else: if W.shape[1] == len(atoms.G2): - Wfft = np.copy(W) + Wfft = W else: Wfft = np.zeros((atoms.Nspin, n, atoms.Nstate), dtype=W.dtype) Wfft[:, atoms.active[0]] = W @@ -184,12 +184,12 @@ def Jdag(atoms, W): if W.ndim == 1: Wfft = Wfft.view(s) - Finv = torch.fft.ifftn(Wfft, s=s, norm='backward').view(n) + Finv = torch.fft.ifftn(Wfft, s=s, norm='forward').view(n) elif W.ndim == 2: Wfft = Wfft.view(s + (atoms.Nstate,)) - Finv = torch.fft.ifftn(Wfft, s=s, norm='backward', dim=(0, 1, 2)).view(n, atoms.Nstate) + Finv = torch.fft.ifftn(Wfft, s=s, norm='forward', dim=(0, 1, 2)).view(n, atoms.Nstate) else: Wfft = Wfft.view((atoms.Nspin,) + s + (atoms.Nstate,)) - Finv = torch.fft.ifftn(Wfft, s=s, norm='backward', dim=(1, 2, 3)).view(atoms.Nspin, n, - atoms.Nstate) - return Finv.detach().cpu().numpy() + Finv = torch.fft.ifftn(Wfft, s=s, norm='forward', dim=(1, 2, 3)).view(atoms.Nspin, n, + atoms.Nstate) + return Finv.detach().cpu().numpy() / n