diff --git a/lobsterpy/cli.py b/lobsterpy/cli.py index df4558cb..f48a3114 100644 --- a/lobsterpy/cli.py +++ b/lobsterpy/cli.py @@ -275,16 +275,32 @@ def get_parser() -> argparse.ArgumentParser: "--orbital", "--orb", type=str, + nargs="+", default=None, help="Orbital name for the site for which DOS are to be added", ) plotting_group.add_argument( "--site", type=int, + nargs="+", default=None, help="Site index in the crystal structure for " "which DOS need to be added", ) + advanced_plotting_args = argparse.ArgumentParser(add_help=False) + advanced_plotting_args.add_argument( + "--invertaxis", + "--invert-axis", + action="store_true", + help="Will invert plot axis of DOS or COOPs COHPs or COBIS", + ) + advanced_plotting_args.add_argument( + "--addtotaldos", + "--add-total-dos", + action="store_true", + help="Will all total dos to the DOS plot", + ) + auto_parent = argparse.ArgumentParser(add_help=False) auto_group = auto_parent.add_argument_group("Automatic analysis") auto_group.add_argument( @@ -456,7 +472,7 @@ def get_parser() -> argparse.ArgumentParser: subparsers.add_parser( "plot-dos", aliases=["plotdos"], - parents=[input_parent, plotting_parent], + parents=[input_parent, plotting_parent, advanced_plotting_args], help=("Will plot DOS from lobster computation."), ) subparsers.add_parser( @@ -1039,8 +1055,8 @@ def run(args): lobs_dos = Doscar(doscar=args.doscar, structure_file=args.poscar).completedos dos_plotter = PlainDosPlotter(summed=args.summedspins, sigma=args.sigma) - - dos_plotter.add_dos(dos=lobs_dos, label="Total DOS") + if args.addtotaldos: + dos_plotter.add_dos(dos=lobs_dos, label="Total DOS") if args.spddos: dos_plotter.add_dos_dict(dos_dict=lobs_dos.get_spd_dos()) @@ -1054,16 +1070,38 @@ def run(args): label = f"{element}: {orbital.name}" dos_plotter.add_dos_dict(dos_dict={label: dos}) - if args.site and args.orbital: - dos_plotter.add_site_orbital_dos( - site_index=args.site, orbital=args.orbital, dos=lobs_dos - ) - elif args.site and not args.orbital: + if args.site is not None and args.orbital: + if len(args.site) > len(args.orbital): + for site in args.site: + for orbital in args.orbital: + dos_plotter.add_site_orbital_dos( + site_index=site, orbital=orbital, dos=lobs_dos + ) + elif len(args.orbital) > len(args.site): + for orbital in args.orbital: + for site in args.site: + dos_plotter.add_site_orbital_dos( + site_index=site, orbital=orbital, dos=lobs_dos + ) + else: + for site, orbital in zip(args.site, args.orbital): + dos_plotter.add_site_orbital_dos( + site_index=site, orbital=orbital, dos=lobs_dos + ) + + elif (args.site is None or not args.orbital) and ( + not args.element and not args.spddos and not args.elementdos + ): raise ValueError( "Please set both args i.e site and orbital to generate the plot" ) - plt = dos_plotter.get_plot(xlim=args.xlim, ylim=args.ylim, beta_dashed=True) + plt = dos_plotter.get_plot( + xlim=args.xlim, + ylim=args.ylim, + beta_dashed=True, + invert_axes=args.invertaxis, + ) ax = plt.gca() ax.set_title(args.title) diff --git a/lobsterpy/plotting/__init__.py b/lobsterpy/plotting/__init__.py index 8a4e5a35..12c14c38 100644 --- a/lobsterpy/plotting/__init__.py +++ b/lobsterpy/plotting/__init__.py @@ -296,7 +296,6 @@ def add_dos(self, label: str, dos: LobsterCompleteDos) -> None: ) densities = {Spin.up: added_densities} else: - print(smeared_densities) densities = smeared_densities else: densities = {Spin.up: dos.get_densities()} @@ -324,45 +323,77 @@ def add_site_orbital_dos(self, dos: LobsterCompleteDos, orbital, site_index): if dos.norm_vol is None: self._norm_val = False site = dos.structure.sites[site_index] - avail_orbs = list(dos.pdos[site]) - if orbital not in avail_orbs: + if orbital not in avail_orbs and orbital != "all": str_orbs = ", ".join(avail_orbs) raise ValueError( f"Requested orbital is not available for this site, " f"available orbitals are {str_orbs}" ) - dos_obj = dos.get_site_orbital_dos(site=site, orbital=orbital) - label = site.species_string + str(site_index + 1) + f": {orbital}" - - energies = dos_obj.energies - if self.summed: - if self.sigma: - smeared_densities = dos_obj.get_smeared_densities(self.sigma) - if Spin.down in smeared_densities: - added_densities = ( - smeared_densities[Spin.up] + smeared_densities[Spin.down] + if orbital == "all": + for orb in avail_orbs: + dos_obj = dos.get_site_orbital_dos(site=site, orbital=orb) + label = site.species_string + str(site_index + 1) + f": {orb}" + energies = dos_obj.energies + if self.summed: + if self.sigma: + smeared_densities = dos_obj.get_smeared_densities(self.sigma) + if Spin.down in smeared_densities: + added_densities = ( + smeared_densities[Spin.up] + + smeared_densities[Spin.down] + ) + densities = {Spin.up: added_densities} + else: + densities = smeared_densities + else: + densities = {Spin.up: dos_obj.get_densities()} + else: + densities = ( + dos_obj.get_smeared_densities(self.sigma) + if self.sigma + else dos_obj.densities ) - densities = {Spin.up: added_densities} + + efermi = dos_obj.efermi + + self._doses[label] = { + "energies": energies, + "densities": densities, + "efermi": efermi, + } + else: + dos_obj = dos.get_site_orbital_dos(site=site, orbital=orbital) + label = site.species_string + str(site_index + 1) + f": {orbital}" + + energies = dos_obj.energies + if self.summed: + if self.sigma: + smeared_densities = dos_obj.get_smeared_densities(self.sigma) + if Spin.down in smeared_densities: + added_densities = ( + smeared_densities[Spin.up] + smeared_densities[Spin.down] + ) + densities = {Spin.up: added_densities} + else: + densities = smeared_densities else: - densities = smeared_densities + densities = {Spin.up: dos_obj.get_densities()} else: - densities = {Spin.up: dos_obj.get_densities()} - else: - densities = ( - dos_obj.get_smeared_densities(self.sigma) - if self.sigma - else dos_obj.densities - ) + densities = ( + dos_obj.get_smeared_densities(self.sigma) + if self.sigma + else dos_obj.densities + ) - efermi = dos_obj.efermi + efermi = dos_obj.efermi - self._doses[label] = { - "energies": energies, - "densities": densities, - "efermi": efermi, - } + self._doses[label] = { + "energies": energies, + "densities": densities, + "efermi": efermi, + } @typing.no_type_check def get_plot( @@ -972,7 +1003,6 @@ def add_icohps(self, label, icohpcollection: IcohpCollection): def get_plot( self, ax: "matplotlib.axes.Axes | None" = None, - style: "matplotlib.plot.style| None" = None, marker_size: float = 50, marker_style: str = "o", xlim: "Tuple[float, float] | None" = None, @@ -984,8 +1014,6 @@ def get_plot( Args: ax: Existing Matplotlib Axes object to plot to. - style: matplotlib style string, if None, will - use lobsterpy style by default. marker_size: sets the size of markers in scatter plots marker_style: sets type of marker used in plot xlim: Specifies the x-axis limits. Defaults to None for diff --git a/lobsterpy/test/test_cli.py b/lobsterpy/test/test_cli.py index 883dd96c..f6724dd9 100644 --- a/lobsterpy/test/test_cli.py +++ b/lobsterpy/test/test_cli.py @@ -445,6 +445,59 @@ def test_dos_plot(self, tmp_path): test = get_parser().parse_args(args) run(test) + os.chdir(TestDir / "TestData/NaCl_comp_range") + args = [ + "plot-dos", + "--site", + "1", + "--orbital", + "all", + ] + + test = get_parser().parse_args(args) + run(test) + + os.chdir(TestDir / "TestData/NaCl_comp_range") + args = [ + "plot-dos", + "--site", + "0", + "1", + "--orbital", + "all", + ] + + test = get_parser().parse_args(args) + run(test) + + os.chdir(TestDir / "TestData/NaCl_comp_range") + args = [ + "plot-dos", + "--site", + "0", + "1", + "--orbital", + "all", + "3s", + ] + + test = get_parser().parse_args(args) + run(test) + + os.chdir(TestDir / "TestData/NaCl_comp_range") + args = [ + "plot-dos", + "--site", + "0", + "--orbital", + "all", + "3s", + "--invertaxis", + ] + + test = get_parser().parse_args(args) + run(test) + def test_cli_exceptions(self): # Calc files missing exception test with pytest.raises(ValueError) as err: @@ -621,7 +674,7 @@ def get_plot_attributes(fig: Figure) -> dict | None: ax = fig.gca() return { - "xydata": [line.get_xydata().tolist() for line in ax.lines], + "xydata": [line.get_xydata().tolist() for line in ax.lines], # type: ignore "facecolor": ax.get_facecolor(), "size": fig.get_size_inches().tolist(), }