Skip to content

Commit

Permalink
Merge pull request #149 from naik-aakash/dosplotter_new_features
Browse files Browse the repository at this point in the history
cli invert axis, add get site all orbitals dos plot
  • Loading branch information
JaGeo authored Sep 15, 2023
2 parents 138f4f4 + 7546691 commit 6311b4d
Show file tree
Hide file tree
Showing 3 changed files with 160 additions and 41 deletions.
56 changes: 47 additions & 9 deletions lobsterpy/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,16 +275,32 @@ def get_parser() -> argparse.ArgumentParser:
"--orbital",
"--orb",
type=str,
nargs="+",
default=None,
help="Orbital name for the site for which DOS are to be added",
)
plotting_group.add_argument(
"--site",
type=int,
nargs="+",
default=None,
help="Site index in the crystal structure for " "which DOS need to be added",
)

advanced_plotting_args = argparse.ArgumentParser(add_help=False)
advanced_plotting_args.add_argument(
"--invertaxis",
"--invert-axis",
action="store_true",
help="Will invert plot axis of DOS or COOPs COHPs or COBIS",
)
advanced_plotting_args.add_argument(
"--addtotaldos",
"--add-total-dos",
action="store_true",
help="Will all total dos to the DOS plot",
)

auto_parent = argparse.ArgumentParser(add_help=False)
auto_group = auto_parent.add_argument_group("Automatic analysis")
auto_group.add_argument(
Expand Down Expand Up @@ -456,7 +472,7 @@ def get_parser() -> argparse.ArgumentParser:
subparsers.add_parser(
"plot-dos",
aliases=["plotdos"],
parents=[input_parent, plotting_parent],
parents=[input_parent, plotting_parent, advanced_plotting_args],
help=("Will plot DOS from lobster computation."),
)
subparsers.add_parser(
Expand Down Expand Up @@ -1039,8 +1055,8 @@ def run(args):
lobs_dos = Doscar(doscar=args.doscar, structure_file=args.poscar).completedos

dos_plotter = PlainDosPlotter(summed=args.summedspins, sigma=args.sigma)

dos_plotter.add_dos(dos=lobs_dos, label="Total DOS")
if args.addtotaldos:
dos_plotter.add_dos(dos=lobs_dos, label="Total DOS")
if args.spddos:
dos_plotter.add_dos_dict(dos_dict=lobs_dos.get_spd_dos())

Expand All @@ -1054,16 +1070,38 @@ def run(args):
label = f"{element}: {orbital.name}"
dos_plotter.add_dos_dict(dos_dict={label: dos})

if args.site and args.orbital:
dos_plotter.add_site_orbital_dos(
site_index=args.site, orbital=args.orbital, dos=lobs_dos
)
elif args.site and not args.orbital:
if args.site is not None and args.orbital:
if len(args.site) > len(args.orbital):
for site in args.site:
for orbital in args.orbital:
dos_plotter.add_site_orbital_dos(
site_index=site, orbital=orbital, dos=lobs_dos
)
elif len(args.orbital) > len(args.site):
for orbital in args.orbital:
for site in args.site:
dos_plotter.add_site_orbital_dos(
site_index=site, orbital=orbital, dos=lobs_dos
)
else:
for site, orbital in zip(args.site, args.orbital):
dos_plotter.add_site_orbital_dos(
site_index=site, orbital=orbital, dos=lobs_dos
)

elif (args.site is None or not args.orbital) and (
not args.element and not args.spddos and not args.elementdos
):
raise ValueError(
"Please set both args i.e site and orbital to generate the plot"
)

plt = dos_plotter.get_plot(xlim=args.xlim, ylim=args.ylim, beta_dashed=True)
plt = dos_plotter.get_plot(
xlim=args.xlim,
ylim=args.ylim,
beta_dashed=True,
invert_axes=args.invertaxis,
)

ax = plt.gca()
ax.set_title(args.title)
Expand Down
90 changes: 59 additions & 31 deletions lobsterpy/plotting/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,7 +296,6 @@ def add_dos(self, label: str, dos: LobsterCompleteDos) -> None:
)
densities = {Spin.up: added_densities}
else:
print(smeared_densities)
densities = smeared_densities
else:
densities = {Spin.up: dos.get_densities()}
Expand Down Expand Up @@ -324,45 +323,77 @@ def add_site_orbital_dos(self, dos: LobsterCompleteDos, orbital, site_index):
if dos.norm_vol is None:
self._norm_val = False
site = dos.structure.sites[site_index]

avail_orbs = list(dos.pdos[site])
if orbital not in avail_orbs:
if orbital not in avail_orbs and orbital != "all":
str_orbs = ", ".join(avail_orbs)
raise ValueError(
f"Requested orbital is not available for this site, "
f"available orbitals are {str_orbs}"
)

dos_obj = dos.get_site_orbital_dos(site=site, orbital=orbital)
label = site.species_string + str(site_index + 1) + f": {orbital}"

energies = dos_obj.energies
if self.summed:
if self.sigma:
smeared_densities = dos_obj.get_smeared_densities(self.sigma)
if Spin.down in smeared_densities:
added_densities = (
smeared_densities[Spin.up] + smeared_densities[Spin.down]
if orbital == "all":
for orb in avail_orbs:
dos_obj = dos.get_site_orbital_dos(site=site, orbital=orb)
label = site.species_string + str(site_index + 1) + f": {orb}"
energies = dos_obj.energies
if self.summed:
if self.sigma:
smeared_densities = dos_obj.get_smeared_densities(self.sigma)
if Spin.down in smeared_densities:
added_densities = (
smeared_densities[Spin.up]
+ smeared_densities[Spin.down]
)
densities = {Spin.up: added_densities}
else:
densities = smeared_densities
else:
densities = {Spin.up: dos_obj.get_densities()}
else:
densities = (
dos_obj.get_smeared_densities(self.sigma)
if self.sigma
else dos_obj.densities
)
densities = {Spin.up: added_densities}

efermi = dos_obj.efermi

self._doses[label] = {
"energies": energies,
"densities": densities,
"efermi": efermi,
}
else:
dos_obj = dos.get_site_orbital_dos(site=site, orbital=orbital)
label = site.species_string + str(site_index + 1) + f": {orbital}"

energies = dos_obj.energies
if self.summed:
if self.sigma:
smeared_densities = dos_obj.get_smeared_densities(self.sigma)
if Spin.down in smeared_densities:
added_densities = (
smeared_densities[Spin.up] + smeared_densities[Spin.down]
)
densities = {Spin.up: added_densities}
else:
densities = smeared_densities
else:
densities = smeared_densities
densities = {Spin.up: dos_obj.get_densities()}
else:
densities = {Spin.up: dos_obj.get_densities()}
else:
densities = (
dos_obj.get_smeared_densities(self.sigma)
if self.sigma
else dos_obj.densities
)
densities = (
dos_obj.get_smeared_densities(self.sigma)
if self.sigma
else dos_obj.densities
)

efermi = dos_obj.efermi
efermi = dos_obj.efermi

self._doses[label] = {
"energies": energies,
"densities": densities,
"efermi": efermi,
}
self._doses[label] = {
"energies": energies,
"densities": densities,
"efermi": efermi,
}

@typing.no_type_check
def get_plot(
Expand Down Expand Up @@ -972,7 +1003,6 @@ def add_icohps(self, label, icohpcollection: IcohpCollection):
def get_plot(
self,
ax: "matplotlib.axes.Axes | None" = None,
style: "matplotlib.plot.style| None" = None,
marker_size: float = 50,
marker_style: str = "o",
xlim: "Tuple[float, float] | None" = None,
Expand All @@ -984,8 +1014,6 @@ def get_plot(
Args:
ax: Existing Matplotlib Axes object to plot to.
style: matplotlib style string, if None, will
use lobsterpy style by default.
marker_size: sets the size of markers in scatter plots
marker_style: sets type of marker used in plot
xlim: Specifies the x-axis limits. Defaults to None for
Expand Down
55 changes: 54 additions & 1 deletion lobsterpy/test/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -445,6 +445,59 @@ def test_dos_plot(self, tmp_path):
test = get_parser().parse_args(args)
run(test)

os.chdir(TestDir / "TestData/NaCl_comp_range")
args = [
"plot-dos",
"--site",
"1",
"--orbital",
"all",
]

test = get_parser().parse_args(args)
run(test)

os.chdir(TestDir / "TestData/NaCl_comp_range")
args = [
"plot-dos",
"--site",
"0",
"1",
"--orbital",
"all",
]

test = get_parser().parse_args(args)
run(test)

os.chdir(TestDir / "TestData/NaCl_comp_range")
args = [
"plot-dos",
"--site",
"0",
"1",
"--orbital",
"all",
"3s",
]

test = get_parser().parse_args(args)
run(test)

os.chdir(TestDir / "TestData/NaCl_comp_range")
args = [
"plot-dos",
"--site",
"0",
"--orbital",
"all",
"3s",
"--invertaxis",
]

test = get_parser().parse_args(args)
run(test)

def test_cli_exceptions(self):
# Calc files missing exception test
with pytest.raises(ValueError) as err:
Expand Down Expand Up @@ -621,7 +674,7 @@ def get_plot_attributes(fig: Figure) -> dict | None:
ax = fig.gca()

return {
"xydata": [line.get_xydata().tolist() for line in ax.lines],
"xydata": [line.get_xydata().tolist() for line in ax.lines], # type: ignore
"facecolor": ax.get_facecolor(),
"size": fig.get_size_inches().tolist(),
}
Expand Down

0 comments on commit 6311b4d

Please sign in to comment.