diff --git a/.github/workflows/jekyll-gh-pages.yml b/.github/workflows/jekyll-gh-pages.yml index 6471481e99c..91f17b3ea39 100644 --- a/.github/workflows/jekyll-gh-pages.yml +++ b/.github/workflows/jekyll-gh-pages.yml @@ -1,15 +1,11 @@ -# Sample workflow for building and deploying a Jekyll site to GitHub Pages name: Deploy Jekyll with GitHub Pages dependencies preinstalled on: - # Runs on pushes targeting the default branch push: branches: ["master"] + workflow_dispatch: # enable manual workflow execution - # Allows you to run this workflow manually from the Actions tab - workflow_dispatch: - -# Sets permissions of the GITHUB_TOKEN to allow deployment to GitHub Pages +# Set permissions of the GITHUB_TOKEN to allow deployment to GitHub Pages permissions: contents: read pages: write @@ -22,7 +18,6 @@ concurrency: cancel-in-progress: false jobs: - # Build job build: # prevent this action from running on forks if: github.repository == 'materialsproject/pymatgen' @@ -30,17 +25,19 @@ jobs: steps: - name: Checkout uses: actions/checkout@v4 + - name: Setup Pages uses: actions/configure-pages@v3 + - name: Build with Jekyll uses: actions/jekyll-build-pages@v1 with: source: ./docs destination: ./_site + - name: Upload artifact uses: actions/upload-pages-artifact@v2 - # Deployment job deploy: environment: name: github-pages diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 07e406823fc..6f4be3704a0 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -109,7 +109,7 @@ jobs: uv pip install -e '.[dev,optional]' --system # TODO remove next line installing ase from main branch when FrechetCellFilter is released - uv pip install --upgrade 'ase@git+https://gitlab.com/ase/ase' --system + uv pip install --upgrade 'git+https://gitlab.com/ase/ase' --system - name: pytest split ${{ matrix.split }} run: | diff --git a/pymatgen/core/lattice.py b/pymatgen/core/lattice.py index acae47500e7..f3e781a8453 100644 --- a/pymatgen/core/lattice.py +++ b/pymatgen/core/lattice.py @@ -9,7 +9,7 @@ import warnings from fractions import Fraction from functools import reduce -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, cast import numpy as np from monty.dev import deprecated @@ -26,6 +26,7 @@ from typing_extensions import Self from pymatgen.core.trajectory import Vector3D + from pymatgen.util.typing import PbcLike __author__ = "Shyue Ping Ong, Michael Kocher" __copyright__ = "Copyright 2011, The Materials Project" @@ -40,8 +41,7 @@ class Lattice(MSONable): """ # Properties lazily generated for efficiency. - - def __init__(self, matrix: ArrayLike, pbc: tuple[bool, bool, bool] = (True, True, True)) -> None: + def __init__(self, matrix: ArrayLike, pbc: PbcLike = (True, True, True)) -> None: """Create a lattice from any sequence of 9 numbers. Note that the sequence is assumed to be read one row at a time. Each row represents one lattice vector. @@ -57,7 +57,7 @@ def __init__(self, matrix: ArrayLike, pbc: tuple[bool, bool, bool] = (True, True E.g., [[10, 0, 0], [20, 10, 0], [0, 0, 30]] specifies a lattice with lattice vectors [10, 0, 0], [20, 10, 0] and [0, 0, 30]. pbc: a tuple defining the periodic boundary conditions along the three - axis of the lattice. If None periodic in all directions. + axis of the lattice. """ mat = np.array(matrix, dtype=np.float64).reshape((3, 3)) mat.setflags(write=False) @@ -66,7 +66,13 @@ def __init__(self, matrix: ArrayLike, pbc: tuple[bool, bool, bool] = (True, True self._diags = None self._lll_matrix_mappings: dict[float, tuple[np.ndarray, np.ndarray]] = {} self._lll_inverse = None - self._pbc = tuple(pbc) + if len(pbc) != 3 or {*pbc} - {True, False}: + raise ValueError(f"pbc must be a tuple of three True/False values, got {pbc}") + + # don't import module-level, causes circular import with util/typing.py + from pymatgen.util.typing import PbcLike + + self._pbc = cast(PbcLike, tuple(pbc)) @property def lengths(self) -> Vector3D: @@ -84,13 +90,12 @@ def angles(self) -> Vector3D: Returns: The angles (alpha, beta, gamma) of the lattice. """ - mat = self._matrix - lengths = self.lengths + matrix, lengths = self._matrix, self.lengths angles = np.zeros(3) for dim in range(3): - j = (dim + 1) % 3 - k = (dim + 2) % 3 - angles[dim] = np.clip(np.dot(mat[j], mat[k]) / (lengths[j] * lengths[k]), -1, 1) + jj = (dim + 1) % 3 + kk = (dim + 2) % 3 + angles[dim] = np.clip(np.dot(matrix[jj], matrix[kk]) / (lengths[jj] * lengths[kk]), -1, 1) angles = np.arccos(angles) * 180.0 / np.pi return tuple(angles.tolist()) # type: ignore @@ -135,9 +140,9 @@ def matrix(self) -> np.ndarray: return self._matrix @property - def pbc(self) -> tuple[bool, bool, bool]: + def pbc(self) -> PbcLike: """Tuple defining the periodicity of the Lattice.""" - return self._pbc # type: ignore + return self._pbc @property def is_3d_periodic(self) -> bool: @@ -207,12 +212,12 @@ def d_hkl(self, miller_index: ArrayLike) -> float: Returns: d_hkl (float) """ - gstar = self.reciprocal_lattice_crystallographic.metric_tensor + g_star = self.reciprocal_lattice_crystallographic.metric_tensor hkl = np.array(miller_index) - return 1 / ((np.dot(np.dot(hkl, gstar), hkl.T)) ** (1 / 2)) + return 1 / ((np.dot(np.dot(hkl, g_star), hkl.T)) ** (1 / 2)) - @staticmethod - def cubic(a: float, pbc: tuple[bool, bool, bool] = (True, True, True)) -> Lattice: + @classmethod + def cubic(cls, a: float, pbc: PbcLike = (True, True, True)) -> Self: """Convenience constructor for a cubic lattice. Args: @@ -223,10 +228,10 @@ def cubic(a: float, pbc: tuple[bool, bool, bool] = (True, True, True)) -> Lattic Returns: Cubic lattice of dimensions a x a x a. """ - return Lattice([[a, 0.0, 0.0], [0.0, a, 0.0], [0.0, 0.0, a]], pbc) + return cls([[a, 0.0, 0.0], [0.0, a, 0.0], [0.0, 0.0, a]], pbc) - @staticmethod - def tetragonal(a: float, c: float, pbc: tuple[bool, bool, bool] = (True, True, True)) -> Lattice: + @classmethod + def tetragonal(cls, a: float, c: float, pbc: PbcLike = (True, True, True)) -> Self: """Convenience constructor for a tetragonal lattice. Args: @@ -238,10 +243,10 @@ def tetragonal(a: float, c: float, pbc: tuple[bool, bool, bool] = (True, True, T Returns: Tetragonal lattice of dimensions a x a x c. """ - return Lattice.from_parameters(a, a, c, 90, 90, 90, pbc=pbc) + return cls.from_parameters(a, a, c, 90, 90, 90, pbc=pbc) - @staticmethod - def orthorhombic(a: float, b: float, c: float, pbc: tuple[bool, bool, bool] = (True, True, True)) -> Lattice: + @classmethod + def orthorhombic(cls, a: float, b: float, c: float, pbc: PbcLike = (True, True, True)) -> Self: """Convenience constructor for an orthorhombic lattice. Args: @@ -254,18 +259,16 @@ def orthorhombic(a: float, b: float, c: float, pbc: tuple[bool, bool, bool] = (T Returns: Orthorhombic lattice of dimensions a x b x c. """ - return Lattice.from_parameters(a, b, c, 90, 90, 90, pbc=pbc) + return cls.from_parameters(a, b, c, 90, 90, 90, pbc=pbc) - @staticmethod - def monoclinic( - a: float, b: float, c: float, beta: float, pbc: tuple[bool, bool, bool] = (True, True, True) - ) -> Lattice: + @classmethod + def monoclinic(cls, a: float, b: float, c: float, beta: float, pbc: PbcLike = (True, True, True)) -> Self: """Convenience constructor for a monoclinic lattice. Args: - a (float): *a* lattice parameter of the monoclinc cell. - b (float): *b* lattice parameter of the monoclinc cell. - c (float): *c* lattice parameter of the monoclinc cell. + a (float): *a* lattice parameter of the monoclinic cell. + b (float): *b* lattice parameter of the monoclinic cell. + c (float): *c* lattice parameter of the monoclinic cell. beta (float): *beta* angle between lattice vectors b and c in degrees. pbc (tuple): a tuple defining the periodic boundary conditions along the three @@ -275,10 +278,10 @@ def monoclinic( Monoclinic lattice of dimensions a x b x c with non right-angle beta between lattice vectors a and c. """ - return Lattice.from_parameters(a, b, c, 90, beta, 90, pbc=pbc) + return cls.from_parameters(a, b, c, 90, beta, 90, pbc=pbc) - @staticmethod - def hexagonal(a: float, c: float, pbc: tuple[bool, bool, bool] = (True, True, True)) -> Lattice: + @classmethod + def hexagonal(cls, a: float, c: float, pbc: PbcLike = (True, True, True)) -> Self: """Convenience constructor for a hexagonal lattice. Args: @@ -290,10 +293,10 @@ def hexagonal(a: float, c: float, pbc: tuple[bool, bool, bool] = (True, True, Tr Returns: Hexagonal lattice of dimensions a x a x c. """ - return Lattice.from_parameters(a, a, c, 90, 90, 120, pbc=pbc) + return cls.from_parameters(a, a, c, 90, 90, 120, pbc=pbc) - @staticmethod - def rhombohedral(a: float, alpha: float, pbc: tuple[bool, bool, bool] = (True, True, True)) -> Lattice: + @classmethod + def rhombohedral(cls, a: float, alpha: float, pbc: PbcLike = (True, True, True)) -> Self: """Convenience constructor for a rhombohedral lattice. Args: @@ -305,7 +308,7 @@ def rhombohedral(a: float, alpha: float, pbc: tuple[bool, bool, bool] = (True, T Returns: Rhombohedral lattice of dimensions a x a x a. """ - return Lattice.from_parameters(a, a, a, alpha, alpha, alpha, pbc=pbc) + return cls.from_parameters(a, a, a, alpha, alpha, alpha, pbc=pbc) @classmethod def from_parameters( @@ -317,8 +320,8 @@ def from_parameters( beta: float, gamma: float, vesta: bool = False, - pbc: tuple[bool, bool, bool] = (True, True, True), - ): + pbc: PbcLike = (True, True, True), + ) -> Self: """Create a Lattice using unit cell lengths (in Angstrom) and angles (in degrees). Args: @@ -360,10 +363,10 @@ def from_parameters( ] vector_c = [0.0, 0.0, float(c)] - return Lattice([vector_a, vector_b, vector_c], pbc) + return cls([vector_a, vector_b, vector_c], pbc) @classmethod - def from_dict(cls, dct: dict, fmt: str | None = None, **kwargs) -> Self: + def from_dict(cls, dct: dict, fmt: str | None = None, **kwargs) -> Self: # type: ignore[override] """Create a Lattice from a dictionary containing the a, b, c, alpha, beta, and gamma parameters if fmt is None. @@ -437,20 +440,22 @@ def params_dict(self) -> dict[str, float]: return dict(zip("a b c alpha beta gamma".split(), self.parameters)) @property - def reciprocal_lattice(self) -> Lattice: + def reciprocal_lattice(self) -> Self: """Return the reciprocal lattice. Note that this is the standard reciprocal lattice used for solid state physics with a factor of 2 * pi. If you are looking for the crystallographic reciprocal lattice, use the reciprocal_lattice_crystallographic property. The property is lazily generated for efficiency. """ - v = np.linalg.inv(self._matrix).T - return Lattice(v * 2 * np.pi) + inv_mat = np.linalg.inv(self._matrix).T + cls = type(self) + return cls(inv_mat * 2 * np.pi) @property - def reciprocal_lattice_crystallographic(self) -> Lattice: + def reciprocal_lattice_crystallographic(self) -> Self: """Returns the *crystallographic* reciprocal lattice, i.e. no factor of 2 * pi.""" - return Lattice(self.reciprocal_lattice.matrix / (2 * np.pi)) + cls = type(self) + return cls(self.reciprocal_lattice.matrix / (2 * np.pi)) @property def lll_matrix(self) -> np.ndarray: @@ -958,17 +963,19 @@ def find_mapping( """ return next(self.find_all_mappings(other_lattice, ltol, atol, skip_rotation_matrix), None) - def get_lll_reduced_lattice(self, delta: float = 0.75) -> Lattice: - """ + def get_lll_reduced_lattice(self, delta: float = 0.75) -> Self: + """Lenstra-Lenstra-Lovasz lattice basis reduction. + Args: delta: Delta parameter. Returns: - LLL reduced Lattice. + Lattice: LLL reduced """ if delta not in self._lll_matrix_mappings: self._lll_matrix_mappings[delta] = self._calculate_lll() - return Lattice(self._lll_matrix_mappings[delta][0]) + cls = type(self) + return cls(self._lll_matrix_mappings[delta][0]) def _calculate_lll(self, delta: float = 0.75) -> tuple[np.ndarray, np.ndarray]: """Performs a Lenstra-Lenstra-Lovasz lattice basis reduction to obtain a @@ -1093,14 +1100,14 @@ def get_niggli_reduced_lattice(self, tol: float = 1e-5) -> Lattice: if B + e < A or (abs(A - B) < e and abs(E) > abs(N) + e): # A1 - M = [[0, -1, 0], [-1, 0, 0], [0, 0, -1]] + M = np.array([[0, -1, 0], [-1, 0, 0], [0, 0, -1]]) G = np.dot(np.transpose(M), np.dot(G, M)) # update lattice parameters based on new G (gh-3657) A, B, C, E, N, Y = G[0, 0], G[1, 1], G[2, 2], 2 * G[1, 2], 2 * G[0, 2], 2 * G[0, 1] if (C + e < B) or (abs(B - C) < e and abs(N) > abs(Y) + e): # A2 - M = [[-1, 0, 0], [0, 0, -1], [0, -1, 0]] + M = np.array([[-1, 0, 0], [0, 0, -1], [0, -1, 0]]) G = np.dot(np.transpose(M), np.dot(G, M)) continue @@ -1134,25 +1141,25 @@ def get_niggli_reduced_lattice(self, tol: float = 1e-5) -> Lattice: # A5 if abs(E) > B + e or (abs(E - B) < e and Y - e > 2 * N) or (abs(E + B) < e and -e > Y): - M = [[1, 0, 0], [0, 1, -E / abs(E)], [0, 0, 1]] + M = np.array([[1, 0, 0], [0, 1, -E / abs(E)], [0, 0, 1]]) G = np.dot(np.transpose(M), np.dot(G, M)) continue # A6 if abs(N) > A + e or (abs(A - N) < e and Y - e > 2 * E) or (abs(A + N) < e and -e > Y): - M = [[1, 0, -N / abs(N)], [0, 1, 0], [0, 0, 1]] + M = np.array([[1, 0, -N / abs(N)], [0, 1, 0], [0, 0, 1]]) G = np.dot(np.transpose(M), np.dot(G, M)) continue # A7 if abs(Y) > A + e or (abs(A - Y) < e and N - e > 2 * E) or (abs(A + Y) < e and -e > N): - M = [[1, -Y / abs(Y), 0], [0, 1, 0], [0, 0, 1]] + M = np.array([[1, -Y / abs(Y), 0], [0, 1, 0], [0, 0, 1]]) G = np.dot(np.transpose(M), np.dot(G, M)) continue # A8 if -e > E + N + Y + A + B or (abs(E + N + Y + A + B) < e < Y + (A + N) * 2): - M = [[1, 0, 1], [0, 1, 1], [0, 0, 1]] + M = np.array([[1, 0, 1], [0, 1, 1], [0, 0, 1]]) G = np.dot(np.transpose(M), np.dot(G, M)) continue @@ -1170,7 +1177,6 @@ def get_niggli_reduced_lattice(self, tol: float = 1e-5) -> Lattice: alpha = math.acos(E / 2 / b / c) / math.pi * 180 beta = math.acos(N / 2 / a / c) / math.pi * 180 gamma = math.acos(Y / 2 / a / b) / math.pi * 180 - lattice = Lattice.from_parameters(a, b, c, alpha, beta, gamma) mapped = self.find_mapping(lattice, e, skip_rotation_matrix=True) @@ -1181,7 +1187,7 @@ def get_niggli_reduced_lattice(self, tol: float = 1e-5) -> Lattice: raise ValueError("can't find niggli") - def scale(self, new_volume: float) -> Lattice: + def scale(self, new_volume: float) -> Self: """Return a new Lattice with volume new_volume by performing a scaling of the lattice vectors so that length proportions and angles are preserved. @@ -1201,7 +1207,7 @@ def scale(self, new_volume: float) -> Lattice: new_c = (new_volume / (geo_factor * np.prod(ratios))) ** (1 / 3.0) - return Lattice(versors * (new_c * ratios), pbc=self.pbc) + return type(self)(versors * (new_c * ratios), pbc=self.pbc) def get_wigner_seitz_cell(self) -> list[list[np.ndarray]]: """Returns the Wigner-Seitz cell for the given lattice. @@ -1692,7 +1698,7 @@ def get_points_in_spheres( all_coords: np.ndarray, center_coords: np.ndarray, r: float, - pbc: bool | list[bool] | tuple[bool, bool, bool] = True, + pbc: bool | list[bool] | PbcLike = True, numerical_tol: float = 1e-8, lattice: Lattice | None = None, return_fcoords: bool = False, @@ -1802,16 +1808,16 @@ def get_points_in_spheres( nn_coords = np.concatenate([cube_to_coords[k] for k in ks], axis=0) nn_images = itertools.chain(*(cube_to_images[k] for k in ks)) nn_indices = itertools.chain(*(cube_to_indices[k] for k in ks)) - dist = np.linalg.norm(nn_coords - ii[None, :], axis=1) + distances = np.linalg.norm(nn_coords - ii[None, :], axis=1) nns: list[tuple[np.ndarray, float, int, np.ndarray]] = [] - for coord, index, image, d in zip(nn_coords, nn_indices, nn_images, dist): + for coord, index, image, dist in zip(nn_coords, nn_indices, nn_images, distances): # filtering out all sites that are beyond the cutoff # Here there is no filtering of overlapping sites - if d < r + numerical_tol: + if dist < r + numerical_tol: if return_fcoords and (lattice is not None): coord = np.round(lattice.get_fractional_coords(coord), 10) - nn = (coord, float(d), int(index), image) - nns.append(nn) + nn = (coord, float(dist), int(index), image) + nns.append(nn) # type: ignore[arg-type] neighbors.append(nns) return neighbors diff --git a/pymatgen/util/typing.py b/pymatgen/util/typing.py index f5fa5a7c0bf..7420fbf26d1 100644 --- a/pymatgen/util/typing.py +++ b/pymatgen/util/typing.py @@ -5,7 +5,7 @@ from __future__ import annotations -from pathlib import Path +from os import PathLike as OsPathLike from typing import TYPE_CHECKING, Any, Union from pymatgen.core import Composition, DummySpecies, Element, Species @@ -18,7 +18,8 @@ from pymatgen.entries.exp_entries import ExpEntry -PathLike = Union[str, Path] +PathLike = Union[str, OsPathLike] +PbcLike = tuple[bool, bool, bool] # Things that can be cast to a Species-like object using get_el_sp SpeciesLike = Union[str, Element, Species, DummySpecies] diff --git a/tests/core/test_lattice.py b/tests/core/test_lattice.py index 578abe643fb..ee342cf16c1 100644 --- a/tests/core/test_lattice.py +++ b/tests/core/test_lattice.py @@ -24,19 +24,25 @@ def setUp(self): self.cubic_partial_pbc = Lattice.cubic(10.0, pbc=(True, True, False)) - family_names = [ - "cubic", - "tetragonal", - "orthorhombic", - "monoclinic", - "hexagonal", - "rhombohedral", - ] - self.families = {} - for name in family_names: + for name in ("cubic", "tetragonal", "orthorhombic", "monoclinic", "hexagonal", "rhombohedral"): self.families[name] = getattr(self, name) + def test_init(self): + len_a = 9.026 + lattice = Lattice.cubic(len_a) + assert lattice is not None, "Initialization from new_cubic failed" + assert_array_equal(lattice.pbc, (True, True, True)) + lattice2 = Lattice(np.eye(3) * len_a) + for ii in range(3): + for jj in range(3): + assert lattice.matrix[ii][jj] == lattice2.matrix[ii][jj], "Inconsistent matrix from two inits!" + assert_array_equal(self.cubic_partial_pbc.pbc, (True, True, False)) + + for bad_pbc in [(True, True), (True, True, True, True), (True, True, 2)]: + with pytest.raises(ValueError, match="pbc must be a tuple of three True/False values, got"): + Lattice(np.eye(3), pbc=bad_pbc) + def test_equal(self): assert self.cubic == self.cubic assert self.cubic == self.lattice @@ -57,17 +63,6 @@ def test_format(self): assert format(self.lattice, ".3f") == lattice_str assert format(self.lattice, ".1fp") == "{10.0, 10.0, 10.0, 90.0, 90.0, 90.0}" - def test_init(self): - len_a = 9.026 - lattice = Lattice.cubic(len_a) - assert lattice is not None, "Initialization from new_cubic failed" - assert_array_equal(lattice.pbc, (True, True, True)) - lattice2 = Lattice(np.eye(3) * len_a) - for ii in range(3): - for jj in range(3): - assert lattice.matrix[ii][jj] == lattice2.matrix[ii][jj], "Inconsistent matrix from two inits!" - assert_array_equal(self.cubic_partial_pbc.pbc, (True, True, False)) - def test_copy(self): cubic_copy = self.cubic.copy() assert cubic_copy == self.cubic @@ -110,8 +105,8 @@ def test_get_vector_along_lattice_directions(self): def test_d_hkl(self): cubic_copy = self.cubic.copy() hkl = (1, 2, 3) - dhkl = ((hkl[0] ** 2 + hkl[1] ** 2 + hkl[2] ** 2) / (cubic_copy.a**2)) ** (-1 / 2) - assert dhkl == cubic_copy.d_hkl(hkl) + d_hkl = ((hkl[0] ** 2 + hkl[1] ** 2 + hkl[2] ** 2) / (cubic_copy.a**2)) ** (-1 / 2) + assert d_hkl == cubic_copy.d_hkl(hkl) def test_reciprocal_lattice(self): recip_latt = self.lattice.reciprocal_lattice @@ -123,8 +118,8 @@ def test_reciprocal_lattice(self): ) # Test the crystallographic version. - recip_latt_xtal = self.lattice.reciprocal_lattice_crystallographic - assert_allclose(recip_latt.matrix, recip_latt_xtal.matrix * 2 * np.pi, 5) + recip_latt_crystallographic = self.lattice.reciprocal_lattice_crystallographic + assert_allclose(recip_latt.matrix, recip_latt_crystallographic.matrix * 2 * np.pi, 5) def test_static_methods(self): expected_lengths = [3.840198, 3.84019885, 3.8401976] diff --git a/tests/util/test_typing.py b/tests/util/test_typing.py index eb1ba9dbf7b..6112e05a933 100644 --- a/tests/util/test_typing.py +++ b/tests/util/test_typing.py @@ -1,24 +1,35 @@ +"""This module tests types are as expected and can be imported without circular ImportError.""" + +# mypy: disable-error-code="misc" + from __future__ import annotations -from typing import Any +import sys +from pathlib import Path +from types import GenericAlias +from typing import Any, get_args -# pymatgen.entries needs to be imported before pymatgen.util.typing -# to avoid circular import. -from pymatgen.entries import Entry -from pymatgen.util.typing import CompositionLike, EntryLike, PathLike, SpeciesLike +import pytest -# This module tests types are as expected and can be imported without circular ImportError. +from pymatgen.core import Composition, DummySpecies, Element, Species +from pymatgen.entries import Entry +from pymatgen.util.typing import CompositionLike, EntryLike, PathLike, PbcLike, SpeciesLike __author__ = "Janosh Riebesell" __date__ = "2022-10-20" __email__ = "janosh@lbl.gov" +skip_below_py310 = pytest.mark.skipif(sys.version_info < (3, 10), reason="requires python 3.10 or higher") + def _type_str(some_type: Any) -> str: return str(some_type).replace("typing.", "").replace("pymatgen.core.periodic_table.", "") def test_entry_like(): + # needs to be tested as string to avoid + # TypeError: issubclass() arg 2 must be a class, a tuple of classes, or a union + # since EntryLike is defined as Union[] of strings to avoid circular imports entries = ( "Entry", "ComputedEntry", @@ -36,16 +47,32 @@ def test_entry_like(): assert Entry.__name__ in str(EntryLike) +@skip_below_py310 def test_species_like(): - assert _type_str(SpeciesLike) == "Union[str, Element, Species, DummySpecies]" + assert isinstance("H", SpeciesLike) + assert isinstance(Element("H"), SpeciesLike) + assert isinstance(Species("H+"), SpeciesLike) + assert isinstance(DummySpecies("X"), SpeciesLike) +@skip_below_py310 def test_composition_like(): - assert ( - _type_str(CompositionLike) - == "Union[str, Element, Species, DummySpecies, dict, pymatgen.core.composition.Composition]" - ) + assert isinstance("H", CompositionLike) + assert isinstance(Element("H"), CompositionLike) + assert isinstance(Species("H+"), CompositionLike) + assert isinstance(Composition("H"), CompositionLike) + assert isinstance({"H": 1}, CompositionLike) + assert isinstance(DummySpecies("X"), CompositionLike) + + +def test_pbc_like(): + assert type(PbcLike) == GenericAlias + assert get_args(PbcLike) == (bool, bool, bool) -def test_path_like(): - assert _type_str(PathLike) == "Union[str, pathlib.Path]" +@skip_below_py310 +def test_pathlike(): + assert isinstance("path/to/file", PathLike) + assert isinstance(Path("path/to/file"), PathLike) + assert not isinstance(1, PathLike) + assert not isinstance(1.0, PathLike)