Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

PhononBSPlotter.plot_compare() add legend labels #3507

Merged
merged 7 commits into from
Dec 10, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/lint.yml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ jobs:
- uses: actions/checkout@v4

- name: Set up Python
uses: actions/setup-python@v4
uses: actions/setup-python@v5
with:
python-version: "3.9"

Expand Down
6 changes: 3 additions & 3 deletions .github/workflows/release.yml
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ jobs:
steps:
- uses: actions/checkout@v4
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v4
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}
cache: pip
Expand Down Expand Up @@ -89,7 +89,7 @@ jobs:
- name: Check out repo
uses: actions/checkout@v4

- uses: actions/setup-python@v4
- uses: actions/setup-python@v5
name: Install Python
with:
python-version: "3.11"
Expand Down Expand Up @@ -134,7 +134,7 @@ jobs:
id-token: write
steps:
- name: Set up Python 3.11
uses: actions/setup-python@v4
uses: actions/setup-python@v5
with:
python-version: 3.11

Expand Down
8 changes: 6 additions & 2 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ jobs:
steps:
- uses: actions/checkout@v4
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v4
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}
cache: pip
Expand Down Expand Up @@ -97,8 +97,12 @@ jobs:
continue-on-error: true # This is not critical to succeed.
- name: Install dependencies
run: |
python -m pip install --upgrade pip wheel
python -m pip install numpy cython packaging

# install ase from main branch until FrechetCellFilter is released
# TODO remove pip install git+https://gitlab.com/ase/ase
pip install git+https://gitlab.com/ase/ase

python -m pip install -e '.[dev,optional]'

- name: pytest split ${{ matrix.split }}
Expand Down
146 changes: 84 additions & 62 deletions pymatgen/phonon/plotter.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,8 @@ def get_plot(
ax (Axes): An existing axes object onto which the plot will be
added. If None, a new figure will be created.
"""
legend = legend or {"fontsize": 30}
legend = legend or {}
legend.setdefault("fontsize", 30)
unit = freq_units(units)

n_colors = max(3, len(self._doses))
Expand Down Expand Up @@ -212,9 +213,9 @@ def get_plot(
ax.set_xlabel(rf"$\mathrm{{Frequencies\ ({unit.label})}}$", fontsize=legend.get("fontsize", 30))
ax.set_ylabel(r"$\mathrm{Density\ of\ states}$", fontsize=legend.get("fontsize", 30))

ax.legend()
ax.legend(**legend)
legend_text = ax.get_legend().get_texts() # all the text.Text instance in the legend
plt.setp(legend_text, **legend)
plt.setp(legend_text)
plt.tight_layout()
return ax

Expand Down Expand Up @@ -261,10 +262,12 @@ def show(
class PhononBSPlotter:
"""Class to plot or get data to facilitate the plot of band structure objects."""

def __init__(self, bs: PhononBandStructureSymmLine) -> None:
def __init__(self, bs: PhononBandStructureSymmLine, label: str | None = None) -> None:
"""
Args:
bs: A PhononBandStructureSymmLine object.
label: A label for the plot. Defaults to None for no label. Esp. useful with
the plot_compare method to distinguish the two band structures.
"""
if not isinstance(bs, PhononBandStructureSymmLine):
raise ValueError(
Expand All @@ -273,7 +276,8 @@ def __init__(self, bs: PhononBandStructureSymmLine) -> None:
"not along symmetry lines won't work)"
)
self._bs = bs
self._nb_bands = self._bs.nb_bands
self._nb_bands = bs.nb_bands
self._label = label

def _make_ticks(self, ax: Axes) -> Axes:
"""Utility private method to add ticks to a band structure."""
Expand All @@ -282,8 +286,8 @@ def _make_ticks(self, ax: Axes) -> Axes:
uniq_d = []
uniq_l = []
temp_ticks = list(zip(ticks["distance"], ticks["label"]))
for i, tt in enumerate(temp_ticks):
if i == 0:
for idx, tt in enumerate(temp_ticks):
if idx == 0:
uniq_d.append(tt[0])
uniq_l.append(tt[1])
else:
Expand All @@ -293,13 +297,13 @@ def _make_ticks(self, ax: Axes) -> Axes:
ax.set_xticks(uniq_d)
ax.set_xticklabels(uniq_l)

for i in range(len(ticks["label"])):
if ticks["label"][i] is not None:
for idx in range(len(ticks["label"])):
if ticks["label"][idx] is not None:
# don't print the same label twice
if i != 0:
ax.axvline(ticks["distance"][i], color="k")
if idx != 0:
ax.axvline(ticks["distance"][idx], color="k")
else:
ax.axvline(ticks["distance"][i], color="k")
ax.axvline(ticks["distance"][idx], color="k")
return ax

def bs_plot_data(self) -> dict[str, Any]:
Expand All @@ -319,12 +323,14 @@ def bs_plot_data(self) -> dict[str, Any]:

ticks = self.get_ticks()

for b in self._bs.branches:
for branch in self._bs.branches:
frequency.append([])
distance.append([self._bs.distance[j] for j in range(b["start_index"], b["end_index"] + 1)])
distance.append([self._bs.distance[j] for j in range(branch["start_index"], branch["end_index"] + 1)])

for i in range(self._nb_bands):
frequency[-1].append([self._bs.bands[i][j] for j in range(b["start_index"], b["end_index"] + 1)])
for idx in range(self._nb_bands):
frequency[-1].append(
[self._bs.bands[idx][j] for j in range(branch["start_index"], branch["end_index"] + 1)]
)

return {
"ticks": ticks,
Expand All @@ -334,29 +340,29 @@ def bs_plot_data(self) -> dict[str, Any]:
}

def get_plot(
self, ylim: float | None = None, units: Literal["thz", "ev", "mev", "ha", "cm-1", "cm^-1"] = "thz"
self, ylim: float | None = None, units: Literal["thz", "ev", "mev", "ha", "cm-1", "cm^-1"] = "thz", **kwargs
) -> Axes:
"""Get a matplotlib object for the bandstructure plot.

Args:
ylim: Specify the y-axis (frequency) limits; by default None let
the code choose.
units: units for the frequencies. Accepted values thz, ev, mev, ha, cm-1, cm^-1.
Defaults to "thz".
**kwargs: passed to ax.plot function.
"""
u = freq_units(units)

ax = pretty_plot(12, 8)

band_linewidth = 1

data = self.bs_plot_data()
for d in range(len(data["distances"])):
for i in range(self._nb_bands):
for idx in range(self._nb_bands):
ax.plot(
data["distances"][d],
[data["frequency"][d][i][j] * u.factor for j in range(len(data["distances"][d]))],
[data["frequency"][d][idx][j] * u.factor for j in range(len(data["distances"][d]))],
"b-",
linewidth=band_linewidth,
**kwargs,
)

self._make_ticks(ax)
Expand Down Expand Up @@ -387,8 +393,8 @@ def _get_weight(self, vec: np.ndarray, indices: list[list[int]]) -> np.ndarray:
"""
num_atom = int(self._nb_bands / 3)
new_vec = np.zeros(num_atom)
for i in range(num_atom):
new_vec[i] = np.linalg.norm(vec[i * 3 : i * 3 + 3])
for idx in range(num_atom):
new_vec[idx] = np.linalg.norm(vec[idx * 3 : idx * 3 + 3])
# get the projectors for each group
gw = []
norm_f = 0
Expand Down Expand Up @@ -445,10 +451,10 @@ def get_proj_plot(
if site_comb == "element":
assert 2 <= len(elements) <= 4, "the compound must have 2, 3 or 4 unique elements"
indices: list[list[int]] = [[] for _ in range(len(elements))]
for i, ele in enumerate(self._bs.structure.species):
for idx, elem in enumerate(self._bs.structure.species):
for j, unique_species in enumerate(self._bs.structure.elements):
if ele == unique_species:
indices[j].append(i)
if elem == unique_species:
indices[j].append(idx)
else:
assert isinstance(site_comb, list)
assert 2 <= len(site_comb) <= 4, "the length of site_comb must be 2, 3 or 4"
Expand Down Expand Up @@ -503,7 +509,7 @@ def get_proj_plot(
elif site_comb == "element":
labels = [e.symbol for e in self._bs.structure.elements]
else:
labels = [f"{i}" for i in range(len(site_comb))]
labels = [f"{idx}" for idx in range(len(site_comb))]
if len(indices) == 2:
BSDOSPlotter._rb_line(ax, labels[0], labels[1], "best")
elif len(indices) == 3:
Expand Down Expand Up @@ -606,7 +612,12 @@ def get_ticks(self) -> dict[str, list]:
return {"distance": tick_distance, "label": tick_labels}

def plot_compare(
self, other_plotter: PhononBSPlotter, units: Literal["thz", "ev", "mev", "ha", "cm-1", "cm^-1"] = "thz"
self,
other_plotter: PhononBSPlotter,
units: Literal["thz", "ev", "mev", "ha", "cm-1", "cm^-1"] = "thz",
labels: tuple[str, str] | None = None,
legend_kwargs: dict | None = None,
**kwargs,
) -> Axes:
"""Plot two band structure for comparison. One is in red the other in blue.
The two band structures need to be defined on the same symmetry lines!
Expand All @@ -617,50 +628,60 @@ def plot_compare(
other_plotter: another PhononBSPlotter object defined along the same symmetry lines
units: units for the frequencies. Accepted values thz, ev, mev, ha, cm-1, cm^-1.
Defaults to 'thz'.
labels: labels for the two band structures. Defaults to None, which will use the
legend of the two PhononBSPlotter objects.
legend_kwargs: kwargs passed to ax.legend().
**kwargs: passed to ax.plot().

Returns:
a matplotlib object with both band structures
"""
u = freq_units(units)
unit = freq_units(units)
legend_kwargs = legend_kwargs or {}
legend_kwargs.setdefault("fontsize", 22)

data_orig = self.bs_plot_data()
data = other_plotter.bs_plot_data()

if len(data_orig["distances"]) != len(data["distances"]):
raise ValueError("The two objects are not compatible.")

ax = self.get_plot(units=units)
band_linewidth = 1
line_width = kwargs.setdefault("linewidth", 1)

ax = self.get_plot(units=units, **kwargs)
for band_idx in range(other_plotter._nb_bands):
for dist_idx in range(len(data_orig["distances"])):
ax.plot(
data_orig["distances"][dist_idx],
[
data["frequency"][dist_idx][band_idx][j] * u.factor
data["frequency"][dist_idx][band_idx][j] * unit.factor
for j in range(len(data_orig["distances"][dist_idx]))
],
"r-",
linewidth=band_linewidth,
**kwargs,
)

# add legend showing which color correspond to which band structure
if labels is None and self._label and other_plotter._label:
labels = (self._label, other_plotter._label)
if labels:
ax.plot([], [], "r-", label=labels[0], linewidth=3 * line_width)
ax.plot([], [], "b-", label=labels[1], linewidth=3 * line_width)
ax.legend(**legend_kwargs)

return ax

def plot_brillouin(self) -> None:
"""Plot the Brillouin zone."""
# get labels and lines
labels = {}
for q in self._bs.qpoints:
if q.label:
labels[q.label] = q.frac_coords
for q_pt in self._bs.qpoints:
if q_pt.label:
labels[q_pt.label] = q_pt.frac_coords

lines = []
for b in self._bs.branches:
lines.append(
[
self._bs.qpoints[b["start_index"]].frac_coords,
self._bs.qpoints[b["end_index"]].frac_coords,
]
)
lines.append([self._bs.qpoints[b["start_index"]].frac_coords, self._bs.qpoints[b["end_index"]].frac_coords])

plot_brillouin_zone(self._bs.lattice_rec, lines=lines, labels=labels)

Expand Down Expand Up @@ -979,14 +1000,18 @@ def bs_plot_data(self) -> dict[str, Any]:

ticks = self.get_ticks()

for b in self._bs.branches:
for branch in self._bs.branches:
frequency.append([])
gruneisen.append([])
distance.append([self._bs.distance[j] for j in range(b["start_index"], b["end_index"] + 1)])
distance.append([self._bs.distance[j] for j in range(branch["start_index"], branch["end_index"] + 1)])

for i in range(self._nb_bands):
frequency[-1].append([self._bs.bands[i][j] for j in range(b["start_index"], b["end_index"] + 1)])
gruneisen[-1].append([self._bs.gruneisen[i][j] for j in range(b["start_index"], b["end_index"] + 1)])
for idx in range(self._nb_bands):
frequency[-1].append(
[self._bs.bands[idx][j] for j in range(branch["start_index"], branch["end_index"] + 1)]
)
gruneisen[-1].append(
[self._bs.gruneisen[idx][j] for j in range(branch["start_index"], branch["end_index"] + 1)]
)

return {
"ticks": ticks,
Expand All @@ -996,29 +1021,26 @@ def bs_plot_data(self) -> dict[str, Any]:
"lattice": self._bs.lattice_rec.as_dict(),
}

def get_plot_gs(self, ylim: float | None = None) -> Axes:
"""Get a matplotlib object for the gruneisen bandstructure plot.
def get_plot_gs(self, ylim: float | None = None, **kwargs) -> Axes:
"""Get a matplotlib object for the Gruneisen bandstructure plot.

Args:
ylim: Specify the y-axis (gruneisen) limits; by default None let
the code choose.
**kwargs: additional keywords passed to ax.plot().
"""
ax = pretty_plot(12, 8)

# band_linewidth = 1
kwargs.setdefault("linewidth", 2)
kwargs.setdefault("marker", "o")
kwargs.setdefault("markersize", 2)

data = self.bs_plot_data()
for d in range(len(data["distances"])):
for i in range(self._nb_bands):
ax.plot(
data["distances"][d],
[data["gruneisen"][d][i][j] for j in range(len(data["distances"][d]))],
"b-",
# linewidth=band_linewidth)
marker="o",
markersize=2,
linewidth=2,
)
for dist_idx in range(len(data["distances"])):
for band_idx in range(self._nb_bands):
ys = [data["gruneisen"][dist_idx][band_idx][idx] for idx in range(len(data["distances"][dist_idx]))]

ax.plot(data["distances"][dist_idx], ys, "b-", **kwargs)

self._make_ticks(ax)

Expand Down
4 changes: 3 additions & 1 deletion tests/core/test_structure.py
Original file line number Diff line number Diff line change
Expand Up @@ -1663,7 +1663,9 @@ def test_relax_m3gnet(self):
relaxed = struct.relax()
assert relaxed.lattice.a == approx(3.867626620642243, abs=0.039) # 1% error
assert hasattr(relaxed, "calc")
assert relaxed.dynamics == {"type": "optimization", "optimizer": "FIRE"}
for key, val in {"type": "optimization", "optimizer": "FIRE"}.items():
actual = relaxed.dynamics[key]
assert actual == val, f"expected {key} to be {val}, {actual=}"

def test_relax_m3gnet_fixed_lattice(self):
pytest.importorskip("matgl")
Expand Down