Skip to content

Commit

Permalink
extras/torch: Fix some wird operator bugs
Browse files Browse the repository at this point in the history
* Remove some copys since they some unneded
* However, when removing the copy from J and Idag the localizer test fails(???)
* Also very strange: I have to normalize "by hand" which makes very little sense when looking at the torch documentation
  • Loading branch information
wangenau committed May 22, 2023
1 parent bbd1da2 commit d4aeb26
Showing 1 changed file with 14 additions and 14 deletions.
28 changes: 14 additions & 14 deletions eminus/extras/torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def I(atoms, W):

if W.ndim < 3:
if len(W) == len(atoms.G2):
Wfft = np.copy(W)
Wfft = W
else:
if W.ndim == 1:
Wfft = np.zeros(n, dtype=W.dtype)
Expand All @@ -48,7 +48,7 @@ def I(atoms, W):
Wfft[atoms.active] = W
else:
if W.shape[1] == len(atoms.G2):
Wfft = np.copy(W)
Wfft = W
else:
Wfft = np.zeros((atoms.Nspin, n, atoms.Nstate), dtype=W.dtype)
Wfft[:, atoms.active[0]] = W
Expand Down Expand Up @@ -132,15 +132,15 @@ def Idag(atoms, W, full=False):

if W.ndim == 1:
Wfft = Wfft.view(s)
F = torch.fft.fftn(Wfft, s=s, norm='backward').view(n)
F = torch.fft.fftn(Wfft, s=s, norm='forward').view(n)
elif W.ndim == 2:
Wfft = Wfft.view(s + (atoms.Nstate,))
F = torch.fft.fftn(Wfft, s=s, norm='backward', dim=(0, 1, 2)).view(n, atoms.Nstate)
F = torch.fft.fftn(Wfft, s=s, norm='forward', dim=(0, 1, 2)).view(n, atoms.Nstate)
else:
Wfft = Wfft.view((atoms.Nspin,) + s + (atoms.Nstate,))
F = torch.fft.fftn(Wfft, s=s, norm='backward', dim=(1, 2, 3)).view(atoms.Nspin, n,
atoms.Nstate)
F = F.detach().cpu().numpy()
F = torch.fft.fftn(Wfft, s=s, norm='forward', dim=(1, 2, 3)).view(atoms.Nspin, n,
atoms.Nstate)
F = F.detach().cpu().numpy() * n
if not full:
if F.ndim < 3:
return F[atoms.active]
Expand All @@ -164,7 +164,7 @@ def Jdag(atoms, W):

if W.ndim < 3:
if len(W) == len(atoms.G2):
Wfft = np.copy(W)
Wfft = W
else:
if W.ndim == 1:
Wfft = np.zeros(n, dtype=W.dtype)
Expand All @@ -173,7 +173,7 @@ def Jdag(atoms, W):
Wfft[atoms.active] = W
else:
if W.shape[1] == len(atoms.G2):
Wfft = np.copy(W)
Wfft = W
else:
Wfft = np.zeros((atoms.Nspin, n, atoms.Nstate), dtype=W.dtype)
Wfft[:, atoms.active[0]] = W
Expand All @@ -184,12 +184,12 @@ def Jdag(atoms, W):

if W.ndim == 1:
Wfft = Wfft.view(s)
Finv = torch.fft.ifftn(Wfft, s=s, norm='backward').view(n)
Finv = torch.fft.ifftn(Wfft, s=s, norm='forward').view(n)
elif W.ndim == 2:
Wfft = Wfft.view(s + (atoms.Nstate,))
Finv = torch.fft.ifftn(Wfft, s=s, norm='backward', dim=(0, 1, 2)).view(n, atoms.Nstate)
Finv = torch.fft.ifftn(Wfft, s=s, norm='forward', dim=(0, 1, 2)).view(n, atoms.Nstate)
else:
Wfft = Wfft.view((atoms.Nspin,) + s + (atoms.Nstate,))
Finv = torch.fft.ifftn(Wfft, s=s, norm='backward', dim=(1, 2, 3)).view(atoms.Nspin, n,
atoms.Nstate)
return Finv.detach().cpu().numpy()
Finv = torch.fft.ifftn(Wfft, s=s, norm='forward', dim=(1, 2, 3)).view(atoms.Nspin, n,
atoms.Nstate)
return Finv.detach().cpu().numpy() / n

0 comments on commit d4aeb26

Please sign in to comment.