Skip to content

Commit

Permalink
BlochWaves structure
Browse files Browse the repository at this point in the history
  • Loading branch information
epolack committed Dec 14, 2023
1 parent 28f6ebe commit bad71d1
Show file tree
Hide file tree
Showing 51 changed files with 344 additions and 255 deletions.
29 changes: 15 additions & 14 deletions examples/error_estimates_forces.jl
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ tol = 1e-5;
# We compute the reference solution ``P_*`` from which we will compute the
# references forces.
scfres_ref = self_consistent_field(basis_ref; tol, callback=identity)
ψ_ref = DFTK.select_occupied_orbitals(basis_ref, scfres_ref.ψ, scfres_ref.occupation).ψ;
ψ_ref = DFTK.select_occupied_orbitals(scfres_ref.ψ, scfres_ref.occupation).ψ;

# We compute a variational approximation of the reference solution with
# smaller `Ecut`. `ψr`, `ρr` and `Er` are the quantities computed with `Ecut`
Expand All @@ -69,16 +69,16 @@ Ecut = 15
basis = PlaneWaveBasis(model; Ecut, kgrid)
scfres = self_consistent_field(basis; tol, callback=identity)
ψr = DFTK.transfer_blochwave(scfres.ψ, basis, basis_ref)
ρr = compute_density(basis_ref, ψr, scfres.occupation)
Er, hamr = energy_hamiltonian(basis_ref, ψr, scfres.occupation; ρ=ρr);
ρr = compute_density(ψr, scfres.occupation)
Er, hamr = energy_hamiltonian(ψr, scfres.occupation; ρ=ρr);

# We then compute several quantities that we need to evaluate the error bounds.

# - Compute the residual ``R(P)``, and remove the virtual orbitals, as required
# in [`src/scf/newton.jl`](https://github.com/JuliaMolSim/DFTK.jl/blob/fedc720dab2d194b30d468501acd0f04bd4dd3d6/src/scf/newton.jl#L121).
res = DFTK.compute_projected_gradient(basis_ref, ψr, scfres.occupation)
res, occ = DFTK.select_occupied_orbitals(basis_ref, res, scfres.occupation)
ψr = DFTK.select_occupied_orbitals(basis_ref, ψr, scfres.occupation).ψ;
res = DFTK.compute_projected_gradient(ψr, scfres.occupation)
res, occ = DFTK.select_occupied_orbitals(BlochWaves(ψr.basis, res), scfres.occupation)
ψr = DFTK.select_occupied_orbitals(ψr, scfres.occupation).ψ;

# - Compute the error ``P-P_*`` on the associated orbitals ``ϕ-ψ`` after aligning
# them: this is done by solving ``\min |ϕ - ψU|`` for ``U`` unitary matrix of
Expand Down Expand Up @@ -129,7 +129,7 @@ function apply_metric(φ, P, δφ, A::Function)
Aδφk
end
end
Mres = apply_metric(ψr, P, res, apply_inv_M);
Mres = apply_metric(ψr.data, P, res, apply_inv_M);

# We can now compute the modified residual ``R_{\rm Schur}(P)`` using a Schur
# complement to approximate the error on low-frequencies[^CDKL2021]:
Expand All @@ -149,7 +149,7 @@ Mres = apply_metric(ψr, P, res, apply_inv_M);

# - Compute the projection of the residual onto the high and low frequencies:
resLF = DFTK.transfer_blochwave(res, basis_ref, basis)
resHF = res - DFTK.transfer_blochwave(resLF, basis, basis_ref);
resHF = denest(res) - denest(DFTK.transfer_blochwave(resLF, basis, basis_ref));

# - Compute ``{\boldsymbol M}^{-1}_{22}R_2(P)``:
e2 = apply_metric(ψr, P, resHF, apply_inv_M);
Expand All @@ -163,15 +163,15 @@ e2 = apply_metric(ψr, P, resHF, apply_inv_M);
end
ΩpKe2 = DFTK.apply_Ω(e2, ψr, hamr, Λ) .+ DFTK.apply_K(basis_ref, e2, ψr, ρr, occ)
ΩpKe2 = DFTK.transfer_blochwave(ΩpKe2, basis_ref, basis)
rhs = resLF - ΩpKe2;
rhs = denest(resLF) - denest(ΩpKe2);

# - Solve the Schur system to compute ``R_{\rm Schur}(P)``: this is the most
# costly step, but inverting ``\boldsymbol{Ω} + \boldsymbol{K}`` on the small space has
# the same cost than the full SCF cycle on the small grid.
(; ψ) = DFTK.select_occupied_orbitals(basis, scfres.ψ, scfres.occupation)
e1 = DFTK.solve_ΩplusK(basis, ψ, rhs, occ; tol).δψ
(; ψ) = DFTK.select_occupied_orbitals(scfres.ψ, scfres.occupation)
e1 = DFTK.solve_ΩplusK(ψ, rhs, occ; tol).δψ
e1 = DFTK.transfer_blochwave(e1, basis, basis_ref)
res_schur = e1 + Mres;
res_schur = denest(e1) + Mres;

# ## Error estimates

Expand All @@ -197,8 +197,9 @@ relerror["F(P)"] = compute_relerror(f);
# To this end, we use the `ForwardDiff.jl` package to compute ``{\rm d}F(P)``
# using automatic differentiation.
function df(basis, occupation, ψ, δψ, ρ)
δρ = DFTK.compute_δρ(basis, ψ, δψ, occupation)
ForwardDiff.derivative-> compute_forces(basis, ψ.+ε.*δψ, occupation; ρ=ρ+ε.*δρ), 0)
δρ = DFTK.compute_δρ(ψ, δψ, occupation)
ForwardDiff.derivative-> compute_forces(BlochWaves.basis, denest(ψ).+ε.*δψ),
occupation; ρ=ρ+ε.*δρ), 0)
end;

# - Computation of the forces by a linearization argument if we have access to
Expand Down
5 changes: 5 additions & 0 deletions examples/geometry_optimization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,11 @@ function compute_scfres(x)
if isnothing(ρ)
ρ = guess_density(basis)
end
if isnothing(ψ)
ψ = BlochWaves(basis)
else
ψ = BlochWaves(basis, denest(ψ))
end
is_converged = DFTK.ScfConvergenceForce(tol / 10)
scfres = self_consistent_field(basis; ψ, ρ, is_converged, callback=identity)
ψ = scfres.ψ
Expand Down
4 changes: 2 additions & 2 deletions examples/publications/2022_cazalis.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@ using Plots
struct Hartree2D end
struct Term2DHartree <: DFTK.TermNonlinear end
(t::Hartree2D)(basis) = Term2DHartree()
function DFTK.ene_ops(term::Term2DHartree, basis::PlaneWaveBasis{T},
ψ, occ; ρ, kwargs...) where {T}
function DFTK.ene_ops(term::Term2DHartree, ψ::BlochWaves{T}, occ; ρ, kwargs...) where {T}
basis = ψ.basis
## 2D Fourier transform of 3D Coulomb interaction 1/|x|
poisson_green_coeffs = 2T(π) ./ [norm(G) for G in G_vectors_cart(basis)]
poisson_green_coeffs[1] = 0 # DC component
Expand Down
8 changes: 6 additions & 2 deletions ext/DFTKJLD2Ext.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,15 +34,19 @@ function DFTK.load_scfres(jld::JLD2.JLDFile)

kpt_properties = (, :occupation, :eigenvalues) # Need splitting over MPI processes
for sym in kpt_properties
scfdict[sym] = jld[string(sym)][basis.krange_thisproc]
if sym ==
scfdict[sym] = nest(basis, jld[string(sym)][basis.krange_thisproc])
else
scfdict[sym] = jld[string(sym)][basis.krange_thisproc]
end
end
for sym in jld["__propertynames"]
sym in (:ham, :basis, , :energies) && continue # special
sym in kpt_properties && continue
scfdict[sym] = jld[string(sym)]
end

energies, ham = energy_hamiltonian(basis, scfdict[], scfdict[:occupation];
energies, ham = energy_hamiltonian(scfdict[], scfdict[:occupation];
ρ=scfdict[], eigenvalues=scfdict[:eigenvalues],
εF=scfdict[:εF])

Expand Down
6 changes: 4 additions & 2 deletions src/DFTK.jl
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,10 @@ export compute_fft_size
export G_vectors, G_vectors_cart, r_vectors, r_vectors_cart
export Gplusk_vectors, Gplusk_vectors_cart
export Kpoint
export to_composite_σG
export from_composite_σG
export BlochWaves, view_component, nest, denest
export blochwave_as_matrix
export blochwave_as_tensor
export blochwaves_as_matrices
export ifft
export irfft
export ifft!
Expand Down
50 changes: 48 additions & 2 deletions src/Psi.jl
Original file line number Diff line number Diff line change
@@ -1,9 +1,55 @@
struct BlochWaves{T, Tψ, Basis <: AbstractBasis{T}}
basis::Basis
data::Vector{VT} where {VT <: AbstractArray3{Tψ}}
n_components::Int
end
function BlochWaves(basis::PlaneWaveBasis, ψ::Vector{VT}) where {VT <: AbstractArray3}
BlochWaves(basis, ψ, basis.model.n_components)
end
function BlochWaves(basis::PlaneWaveBasis, ψ::Vector{VT}) where {VT <: AbstractMatrix}
n_components = basis.model.n_components
BlochWaves(basis, [reshape(ψk, n_components, :, size(ψk, 2)) for ψk in ψ], n_components)
end
BlochWaves(basis) = BlochWaves(basis, [Array{eltype(basis), 3}(undef, 0, 0, 0)], 0)

# Helpers function to directly have access to the `data` field.
Base.getindex::BlochWaves, indices...) = ψ.data[indices...]
Base.isnothing::BlochWaves) = iszero.data)
Base.iterate::BlochWaves, args...) = Base.iterate.data, args...)
Base.length::BlochWaves) = length.data)
Base.similar::BlochWaves, args...) = similar.data, args...)
Base.size::BlochWaves, args...) = Base.size.data, args...)

Check warning on line 21 in src/Psi.jl

View check run for this annotation

Codecov / codecov/patch

src/Psi.jl#L21

Added line #L21 was not covered by tests

# T@D@: Do not change flatten the size of array by default (1:1)
# T@D@: replace with iterations with eachslice(ψk; dims=1)

@doc raw"""
view_component(ψk::AbstractArray3, σ)
View the ``σ``-th component(s) of the wave function `ψk`.
It returns a 2D matrix if `σ` is an integer or a 3D array if it is a list.
"""
@views view_component(ψk::AbstractArray3, σ) = ψk[σ, :, :]
# Apply the previous function for each k-point.
@views view_component::BlochWaves, σ) = [view_component(ψk, σ) for ψk in ψ]

Check warning on line 34 in src/Psi.jl

View check run for this annotation

Codecov / codecov/patch

src/Psi.jl#L34

Added line #L34 was not covered by tests
"""
denest(ψ::BlochWaves; σ)
Returns the arrays containing the data from the `BlochWaves` structure `ψ`.
If `σ` is given, we can ask for only some components to be extracted.
"""
@views denest::BlochWaves; σ=1:ψ.basis.model.n_components) = [view_component(ψk, σ) for ψk in ψ]
@views denest(basis, ψ::Vector; σ=1:basis.model.n_components) = [view_component(ψk, σ) for ψk in ψ]
# Wrapper around the BlochWaves creator to have an inverse to the `denest` function.
nest(basis, ψ::Vector{A}) where {A <: AbstractArray} = BlochWaves(basis, ψ)

eachcomponent(ψk::AbstractArray3) = eachslice(ψk; dims=1)
eachband(ψk::AbstractArray3) = eachslice(ψk; dims=3)

Check warning on line 47 in src/Psi.jl

View check run for this annotation

Codecov / codecov/patch

src/Psi.jl#L46-L47

Added lines #L46 - L47 were not covered by tests

@views blochwave_as_tensor(ψk::AbstractMatrix, n_components) = reshape(ψk, n_components, :, size(ψk, 2))
@views blochwave_as_matrix(ψk::AbstractArray3) = reshape(ψk, :, size(ψk, 3))
# reduce along component direction
@views blochwave_as_matorvec(ψk::AbstractArray3) = reshape(ψk, :, size(ψk, 3))
@views blochwave_as_matorvec(ψk::AbstractMatrix) = reshape(ψk, size(ψk, 2))
# Works for BlochWaves & Vector(AbstractArray3).
@views blochwaves_as_matrices(ψ) = @views [reshape(ψk, :, size(ψk, 3)) for ψk in ψ]
to_composite_σG() = nothing
from_composite_σG() = nothing
34 changes: 18 additions & 16 deletions src/densities.jl
Original file line number Diff line number Diff line change
@@ -1,23 +1,24 @@
# Densities (and potentials) are represented by arrays
# ρ[ix,iy,iz,iσ] in real space, where iσ ∈ [1:n_spin_components]

# TODO: We reduce all components for the density. Will need to be though again when we merge
# the components and the spins.
"""
compute_density(basis::PlaneWaveBasis, ψ::AbstractVector, occupation::AbstractVector)
compute_density(ψ::BlochWaves, occupation::AbstractVector)
Compute the density for a wave function `ψ` discretized on the plane-wave
grid `basis`, where the individual k-points are occupied according to `occupation`.
`ψ` should be one coefficient matrix per ``k``-point.
Compute the density for a wave function `ψ` discretized on the plane-wave grid `ψ.basis`,
where the individual k-points are occupied according to `occupation`.
`ψ` should contain one coefficient matrix per ``k``-point.
It is possible to ask only for occupations higher than a certain level to be computed by
using an optional `occupation_threshold`. By default all occupation numbers are considered.
"""
@views @timing function compute_density(basis::PlaneWaveBasis{T}, ψ, occupation;
occupation_threshold=zero(T)) where {T}
S = promote_type(T, real(eltype(ψ[1])))
# TODO: We reduce all components for the density. Will need to be though again when we merge
# the components and the spins.
@views @timing function compute_density::BlochWaves{T, Tψ}, occupation;

Check warning on line 15 in src/densities.jl

View check run for this annotation

Codecov / codecov/patch

src/densities.jl#L15

Added line #L15 was not covered by tests
occupation_threshold=zero(T)) where {T, Tψ}
S = promote_type(T, real(Tψ))
# occupation should be on the CPU as we are going to be doing scalar indexing.
occupation = [to_cpu(oc) for oc in occupation]

basis = ψ.basis
mask_occ = [findall(occnk -> abs(occnk) occupation_threshold, occk)
for occk in occupation]
if all(isempty, mask_occ) # No non-zero occupations => return zero density
Expand Down Expand Up @@ -66,21 +67,22 @@ using an optional `occupation_threshold`. By default all occupation numbers are
end

# Variation in density corresponding to a variation in the orbitals and occupations.
@views @timing function compute_δρ(basis::PlaneWaveBasis{T}, ψ, δψ,
occupation, δoccupation=zero.(occupation);
@views @timing function compute_δρ(ψ::BlochWaves{T}, δψ, occupation,

Check warning on line 70 in src/densities.jl

View check run for this annotation

Codecov / codecov/patch

src/densities.jl#L70

Added line #L70 was not covered by tests
δoccupation=zero.(occupation);
occupation_threshold=zero(T)) where {T}
ForwardDiff.derivative(zero(T)) do ε
ψ_ε = [ψk .+ ε .* δψk for (ψk, δψk) in zip(ψ, δψ)]
occ_ε = [occk .+ ε .* δocck for (occk, δocck) in zip(occupation, δoccupation)]
compute_density(basis, ψ_ε, occ_ε; occupation_threshold)
compute_density(BlochWaves.basis, ψ_ε), occ_ε; occupation_threshold)
end
end

@views @timing function compute_kinetic_energy_density(basis::PlaneWaveBasis{TT}, ψ,
occupation) where {TT}
@views @timing function compute_kinetic_energy_density::BlochWaves{T, Tψ},

Check warning on line 80 in src/densities.jl

View check run for this annotation

Codecov / codecov/patch

src/densities.jl#L80

Added line #L80 was not covered by tests
occupation) where {T, Tψ}
basis = ψ.basis
@assert basis.model.n_components == 1
T = promote_type(TT, real(eltype(ψ[1])))
τ = similar(ψ[1], T, (basis.fft_size..., basis.model.n_spin_components))
TT = promote_type(T, real())
τ = similar(ψ[1], TT, (basis.fft_size..., basis.model.n_spin_components))
τ .= 0
dαψnk_real = zeros(complex(T), basis.fft_size)
for (ik, kpt) in enumerate(basis.kpoints)
Expand Down
9 changes: 5 additions & 4 deletions src/orbitals.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,19 +4,20 @@ using Random # Used to have a generic API for CPU and GPU computations alike: s
# virtual states (or states with small occupation level for metals).
# threshold is a parameter to distinguish between states we want to keep and the
# others when using temperature. It is set to 0.0 by default, to treat with insulators.
function select_occupied_orbitals(basis, ψ, occupation; threshold=0.0)
function select_occupied_orbitals(ψ, occupation; threshold=0.0)
N = [something(findlast(x -> x > threshold, occk), 0) for occk in occupation]
selected_ψ = [@view ψk[:, :, 1:N[ik]] for (ik, ψk) in enumerate(ψ)]
selected_occ = [ occk[1:N[ik]] for (ik, occk) in enumerate(occupation)]

ψ = BlochWaves.basis, selected_ψ)
# If we have an insulator, sanity check that the orbitals we kept are the occupied ones.
if iszero(threshold)
model = basis.model
model = ψ.basis.model
n_spin = model.n_spin_components
n_bands = div(model.n_electrons, n_spin * filled_occupation(model), RoundUp)
@assert all([n_bands == size(ψk, 3) for ψk in selected_ψ])
@assert all([n_bands == size(ψk, 3) for ψk in ψ])
end
(; ψ=selected_ψ, occupation=selected_occ)
(; ψ, occupation=selected_occ)
end

# Packing routines used in direct_minimization and newton algorithms.
Expand Down
2 changes: 1 addition & 1 deletion src/postprocess/dos.jl
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ function compute_ldos(ε, basis::PlaneWaveBasis{T}, eigenvalues, ψ;
# Use compute_density routine to compute LDOS, using just the modified
# weights (as "occupations") at each k-point. Note, that this automatically puts in the
# required symmetrization with respect to kpoints and BZ symmetry
compute_density(basis, ψ, weights; occupation_threshold=weight_threshold)
compute_density(ψ, weights; occupation_threshold=weight_threshold)
end
function compute_ldos(scfres::NamedTuple; ε=scfres.εF, kwargs...)
compute_ldos(ε, scfres.basis, scfres.eigenvalues, scfres.ψ; kwargs...)
Expand Down
15 changes: 8 additions & 7 deletions src/postprocess/forces.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,11 @@ lattice vectors. To get cartesian forces use [`compute_forces_cart`](@ref).
Returns a list of lists of forces (as SVector{3}) in the same order as the `atoms`
and `positions` in the underlying [`Model`](@ref).
"""
@timing function compute_forces(basis::PlaneWaveBasis{T}, ψ, occupation; kwargs...) where {T}
@timing function compute_forces::BlochWaves{T}, occupation; kwargs...) where {T}

Check warning on line 8 in src/postprocess/forces.jl

View check run for this annotation

Codecov / codecov/patch

src/postprocess/forces.jl#L8

Added line #L8 was not covered by tests
basis = ψ.basis
# no explicit symmetrization is performed here, it is the
# responsability of each term to return symmetric forces
forces_per_term = [compute_forces(term, basis, ψ, occupation; kwargs...)
forces_per_term = [compute_forces(term, ψ, occupation; kwargs...)
for term in basis.terms]
sum(filter(!isnothing, forces_per_term))
end
Expand All @@ -19,14 +20,14 @@ Returns a list of lists of forces
`[[force for atom in positions] for (element, positions) in atoms]`
which has the same structure as the `atoms` object passed to the underlying [`Model`](@ref).
"""
function compute_forces_cart(basis::PlaneWaveBasis, ψ, occupation; kwargs...)
forces_reduced = compute_forces(basis, ψ, occupation; kwargs...)
covector_red_to_cart.(basis.model, forces_reduced)
function compute_forces_cart(ψ::BlochWaves, occupation; kwargs...)
forces_reduced = compute_forces(ψ, occupation; kwargs...)
covector_red_to_cart.(ψ.basis.model, forces_reduced)
end

function compute_forces(scfres)
compute_forces(scfres.basis, scfres.ψ, scfres.occupation; scfres.ρ)
compute_forces(scfres.ψ, scfres.occupation; scfres.ρ)
end
function compute_forces_cart(scfres)
compute_forces_cart(scfres.basis, scfres.ψ, scfres.occupation; scfres.ρ)
compute_forces_cart(scfres.ψ, scfres.occupation; scfres.ρ)
end
5 changes: 3 additions & 2 deletions src/postprocess/stresses.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,9 @@ Compute the stresses (= 1/Vol dE/d(M*lattice), taken at M=I) of an obtained SCF
basis.kgrid, basis.symmetries_respect_rgrid,
basis.use_symmetries_for_kpoint_reduction,
basis.comm_kpts, basis.architecture)
ρ = compute_density(new_basis, scfres.ψ, scfres.occupation)
energies = energy_hamiltonian(new_basis, scfres.ψ, scfres.occupation;
ψ = BlochWaves(new_basis, denest(scfres.ψ))
ρ = compute_density(ψ, scfres.occupation)
energies = energy_hamiltonian(ψ, scfres.occupation;
ρ, scfres.eigenvalues, scfres.εF).energies
energies.total
end
Expand Down
4 changes: 2 additions & 2 deletions src/response/chi0.jl
Original file line number Diff line number Diff line change
Expand Up @@ -410,9 +410,9 @@ function apply_χ0(ham, ψ, occupation, εF, eigenvalues, δV::AbstractArray{T};

δHψ = [RealSpaceMultiplication(basis, kpt, @views δV[:, :, :, kpt.spin]) * ψ[ik]
for (ik, kpt) in enumerate(basis.kpoints)]
(; δψ, δoccupation) = apply_χ0_4P(ham, ψ, occupation, εF, eigenvalues, δHψ;
(; δψ, δoccupation) = apply_χ0_4P(ham, denest(ψ), occupation, εF, eigenvalues, δHψ;
occupation_threshold, kwargs_sternheimer...)
δρ = compute_δρ(basis, ψ, δψ, occupation, δoccupation; occupation_threshold)
δρ = compute_δρ(ψ, δψ, occupation, δoccupation; occupation_threshold)
δρ * normδV
end

Expand Down
Loading

0 comments on commit bad71d1

Please sign in to comment.