From c7038720c135ac703f51ac6a4b313e375d743fa3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ha=CC=8Akon=20Wiik=20A=CC=8Anes?= Date: Sun, 15 Dec 2024 15:51:15 +0100 Subject: [PATCH 1/2] Update type hints in symmetry module to >= 3.10 MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Håkon Wiik Ånes --- orix/quaternion/symmetry.py | 38 ++++++++++++++++++++----------------- 1 file changed, 21 insertions(+), 17 deletions(-) diff --git a/orix/quaternion/symmetry.py b/orix/quaternion/symmetry.py index 9f13ab31..3f519869 100644 --- a/orix/quaternion/symmetry.py +++ b/orix/quaternion/symmetry.py @@ -17,15 +17,19 @@ from __future__ import annotations -from typing import Dict, List, Optional, Tuple, Union +from typing import TYPE_CHECKING from diffpy.structure.spacegroups import GetSpaceGroup -import matplotlib.pyplot as plt +import matplotlib.figure as mfigure import numpy as np from orix.quaternion.rotation import Rotation from orix.vector import Vector3d +if TYPE_CHECKING: # pragma: no cover + from orix.quaternion import Orientation + from orix.vector import FundamentalSector + class Symmetry(Rotation): r"""The set of rotations comprising a point group. @@ -63,22 +67,22 @@ def order(self) -> int: @property def is_proper(self) -> bool: """Return whether this group contains only proper rotations.""" - return np.all(np.equal(self.improper, 0)) + return bool(np.all(np.equal(self.improper, 0))) @property - def subgroups(self) -> List[Symmetry]: + def subgroups(self) -> list[Symmetry]: """Return the list groups that are subgroups of this group.""" return [g for g in _groups if g._tuples <= self._tuples] @property - def proper_subgroups(self) -> List[Symmetry]: + def proper_subgroups(self) -> list[Symmetry]: """Return the list of proper groups that are subgroups of this group. """ return [g for g in self.subgroups if g.is_proper] @property - def proper_subgroup(self) -> Union[List[Symmetry], Symmetry]: + def proper_subgroup(self) -> Symmetry: """Return the largest proper group of this subgroup.""" subgroups = self.proper_subgroups if len(subgroups) == 0: @@ -152,7 +156,7 @@ def euler_fundamental_region(self) -> tuple: return region @property - def system(self) -> Union[str, None]: + def system(self) -> str | None: """Return which of the seven crystal systems this symmetry belongs to. @@ -187,7 +191,7 @@ def _tuples(self) -> set: return tuples @property - def fundamental_sector(self) -> "orix.vector.FundamentalSector": + def fundamental_sector(self) -> "FundamentalSector": """Return the fundamental sector describing the inverse pole figure given by the point group name. @@ -258,7 +262,7 @@ def fundamental_sector(self) -> "orix.vector.FundamentalSector": return fs @property - def _primary_axis_order(self) -> Union[int, None]: + def _primary_axis_order(self) -> int | None: """Return the order of primary rotation axis for the proper subgroup. @@ -351,7 +355,7 @@ def __and__(self, other: Symmetry) -> Symmetry: generators = [g for g in self.subgroups if g in other.subgroups] return Symmetry.from_generators(*generators) - def __hash__(self) -> hash: + def __hash__(self) -> int: return hash(self.name.encode() + self.data.tobytes() + self.improper.tobytes()) # ------------------------ Class methods ------------------------- # @@ -404,7 +408,7 @@ def from_generators(cls, *generators: Rotation) -> Symmetry: # --------------------- Other public methods --------------------- # - def get_axis_orders(self) -> Dict[Vector3d, int]: + def get_axis_orders(self) -> dict[Vector3d, int]: s = self[self.angle > 0] if s.size == 0: return {} @@ -413,7 +417,7 @@ def get_axis_orders(self) -> Dict[Vector3d, int]: for a, b in zip(*np.unique(s.axis.data, axis=0, return_counts=True)) } - def get_highest_order_axis(self) -> Tuple[Vector3d, np.ndarray]: + def get_highest_order_axis(self) -> tuple[Vector3d, np.ndarray]: axis_orders = self.get_axis_orders() if len(axis_orders) == 0: return Vector3d.zvector(), np.inf @@ -424,7 +428,7 @@ def get_highest_order_axis(self) -> Tuple[Vector3d, np.ndarray]: return axes, highest_order def fundamental_zone(self) -> Vector3d: - from orix.vector import AxAngle, SphericalRegion + from orix.vector import SphericalRegion symmetry = self.antipodal symmetry = symmetry[symmetry.angle > 0] @@ -459,10 +463,10 @@ def fundamental_zone(self) -> Vector3d: def plot( self, - orientation: "orix.quaternion.Orientation" = None, - reproject_scatter_kwargs: Optional[dict] = None, + orientation: "Orientation | None" = None, + reproject_scatter_kwargs: dict | None = None, **kwargs, - ) -> plt.Figure: + ) -> mfigure.Figure | None: """Stereographic projection of symmetry operations. The upper hemisphere of the stereographic projection is shown. @@ -824,7 +828,7 @@ def get_point_group(space_group_number: int, proper: bool = False) -> Symmetry: } -def _get_laue_group_name(name: str) -> Union[str, None]: +def _get_laue_group_name(name: str) -> str | None: if name in ["1", "-1"]: return "-1" elif name in ["2", "211", "121", "112", "m11", "1m1", "11m", "2/m"]: From a485687f8acdc0404ff0dc925d179314b2ee4855 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ha=CC=8Akon=20Wiik=20A=CC=8Anes?= Date: Sun, 15 Dec 2024 16:12:28 +0100 Subject: [PATCH 2/2] Complete EDAX point group aliases, rename dictionary to something more descriptive MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Håkon Wiik Ånes --- orix/crystal_map/phase_list.py | 6 +++--- orix/io/plugins/ang.py | 6 +++--- orix/quaternion/symmetry.py | 4 +++- orix/tests/io/test_ang.py | 2 +- 4 files changed, 10 insertions(+), 8 deletions(-) diff --git a/orix/crystal_map/phase_list.py b/orix/crystal_map/phase_list.py index 4e12185e..9ee09c48 100644 --- a/orix/crystal_map/phase_list.py +++ b/orix/crystal_map/phase_list.py @@ -31,10 +31,10 @@ import numpy as np from orix.quaternion.symmetry import ( + _EDAX_POINT_GROUP_ALIASES, Symmetry, _groups, get_point_group, - point_group_aliases, ) from orix.vector import Miller, Vector3d @@ -229,9 +229,9 @@ def point_group(self, value: int | str | Symmetry | None) -> None: if isinstance(value, int): value = str(value) if isinstance(value, str): - for correct, aliases in point_group_aliases.items(): + for key, aliases in _EDAX_POINT_GROUP_ALIASES.items(): if value in aliases: - value = correct + value = key break for point_group in _groups: if value == point_group.name: diff --git a/orix/io/plugins/ang.py b/orix/io/plugins/ang.py index c93c8b8c..6e770427 100644 --- a/orix/io/plugins/ang.py +++ b/orix/io/plugins/ang.py @@ -30,7 +30,7 @@ from orix import __version__ from orix.crystal_map import CrystalMap, PhaseList, create_coordinate_arrays from orix.quaternion import Rotation -from orix.quaternion.symmetry import point_group_aliases +from orix.quaternion.symmetry import _EDAX_POINT_GROUP_ALIASES __all__ = ["file_reader", "file_writer"] @@ -594,9 +594,9 @@ def _get_header_from_phases(xmap: CrystalMap) -> str: else: proper_point_group = phase.point_group.proper_subgroup point_group_name = proper_point_group.name - for key, alias in point_group_aliases.items(): + for key, aliases in _EDAX_POINT_GROUP_ALIASES.items(): if point_group_name == key: - point_group_name = alias[0] + point_group_name = aliases[0] break header += ( f"Phase {phase_id}\n" diff --git a/orix/quaternion/symmetry.py b/orix/quaternion/symmetry.py index 3f519869..ec3061ef 100644 --- a/orix/quaternion/symmetry.py +++ b/orix/quaternion/symmetry.py @@ -818,11 +818,13 @@ def get_point_group(space_group_number: int, proper: bool = False) -> Symmetry: # Point group alias mapping. This is needed because in EDAX TSL OIM # Analysis 7.2, e.g. point group 432 is entered as 43. # Used when reading a phase's point group from an EDAX ANG file header -point_group_aliases = { +_EDAX_POINT_GROUP_ALIASES = { "121": ["20"], "2/m": ["2"], "222": ["22"], "422": ["42"], + "321": ["32"], + "622": ["62"], "432": ["43"], "m-3m": ["m3m"], } diff --git a/orix/tests/io/test_ang.py b/orix/tests/io/test_ang.py index 293aa526..aa461639 100644 --- a/orix/tests/io/test_ang.py +++ b/orix/tests/io/test_ang.py @@ -608,7 +608,7 @@ def test_extra_phases(self, crystal_map, tmp_path, extra_phase_names): del pl[-1] assert xmap_reload.phases.names == pl.names - @pytest.mark.parametrize("point_group", ["432", "121", "222"]) + @pytest.mark.parametrize("point_group", ["432", "121", "222", "321", "622"]) def test_point_group_aliases(self, crystal_map, tmp_path, point_group): crystal_map.phases[0].point_group = point_group fname = tmp_path / "test_point_group_aliases.ang"