extend ruff linter (#315)
* extend ruff linter

* fix broken comparison
jan-janssen authored Feb 5, 2025
1 parent d1cc0a4 commit dd56d66
Showing 20 changed files with 329 additions and 163 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
@@ -4,7 +4,7 @@ repos:
hooks:
- id: ruff
name: ruff lint
args: ["--select", "I", "--fix"]
args: ["--fix"]
files: ^structuretoolkit/
- id: ruff-format
name: ruff format
47 changes: 47 additions & 0 deletions pyproject.toml
@@ -63,6 +63,53 @@ include = ["structuretoolkit*"]
[tool.setuptools.dynamic]
version = {attr = "structuretoolkit.__version__"}

[tool.ruff]
exclude = [".ci_support", "tests", "setup.py", "_version.py"]

[tool.ruff.lint]
select = [
# pycodestyle
"E",
# Pyflakes
"F",
# pyupgrade
"UP",
# flake8-bugbear
"B",
# flake8-simplify
"SIM",
# isort
"I",
# flake8-comprehensions
"C4",
# eradicate
"ERA",
# pylint
"PL",
]
ignore = [
# ignore functions in argument defaults
"B008",
# ignore exception naming
"B904",
# ignore line-length violations
"E501",
# ignore equality comparisons for numpy arrays
"E712",
# ignore bare except
"E722",
# ignore ambiguous variable name
"E741",
# Too many arguments in function definition
"PLR0913",
# Magic value used in comparison
"PLR2004",
# Too many branches
"PLR0912",
# Too many statements
"PLR0915",
]

[tool.versioneer]
VCS = "git"
style = "pep440-pre"
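For orientation, the newly selected rule families flag patterns like the ones below. This is an illustrative sketch, not code from the repository, showing what flake8-comprehensions (C4) and flake8-simplify (SIM) typically catch:

import numpy as np

# C4: unnecessary list() call around a generator expression
squares = list(x**2 for x in range(5))  # flagged; prefer [x**2 for x in range(5)]

# SIM: if/else that only returns True/False
def is_diagonal(cell: np.ndarray) -> bool:
    if np.allclose(cell, np.diag(np.diagonal(cell))):
        return True
    else:
        return False
    # ruff would suggest returning the condition directly

print(is_diagonal(np.eye(3)))  # True
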
58 changes: 58 additions & 0 deletions structuretoolkit/__init__.py
@@ -98,4 +98,62 @@

from . import _version

__all__ = [
"find_mic",
"find_solids",
"get_adaptive_cna_descriptors",
"get_average_of_unique_labels",
"get_centro_symmetry_descriptors",
"get_cluster_positions",
"get_delaunay_neighbors",
"get_diamond_structure_descriptors",
"get_distances_array",
"get_equivalent_atoms",
"get_interstitials",
"get_layers",
"get_mean_positions",
"get_neighborhood",
"get_neighbors",
"get_steinhardt_parameters",
"get_strain",
"get_symmetry",
"get_voronoi_neighbors",
"get_voronoi_vertices",
"get_voronoi_volumes",
"analyse_find_solids",
"analyse_cna_adaptive",
"analyse_centro_symmetry",
"cluster_positions",
"analyse_diamond_structure",
"analyse_phonopy_equivalent_atoms",
"get_steinhardt_parameter_structure",
"analyse_voronoi_volume",
"B2",
"C14",
"C15",
"C36",
"D03",
"create_mesh",
"get_grainboundary_info",
"get_high_index_surface_info",
"grainboundary",
"high_index_surface",
"sqs_structures",
"grainboundary_info",
"high_index_surface_info",
"grainboundary_build",
"get_sqs_structures",
"SymmetryError",
"apply_strain",
"ase_to_pymatgen",
"ase_to_pyscal",
"center_coordinates_in_unit_cell",
"get_cell",
"get_extended_positions",
"get_vertical_length",
"get_wrapped_coordinates",
"pymatgen_to_ase",
"select_index",
"plot3d",
]
__version__ = _version.get_versions()["version"]
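The __all__ added above declares the package's public API. A minimal sketch of what that changes for downstream code (assuming the package is installed and the names listed above are importable at package level):

import structuretoolkit as stk

print("get_neighbors" in stk.__all__)  # True with the list added above

# `from structuretoolkit import *` now binds exactly the names in __all__,
# rather than whatever happens to live in the module namespace.
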
36 changes: 36 additions & 0 deletions structuretoolkit/analyse/__init__.py
@@ -254,3 +254,39 @@ def get_ir_reciprocal_mesh(
is_shift=is_shift,
is_time_reversal=is_time_reversal,
)


__all__ = [
"find_mic",
"get_distances_array",
"soap_descriptor_per_atom",
"get_neighborhood",
"get_neighbors",
"get_equivalent_atoms",
"find_solids",
"get_adaptive_cna_descriptors",
"get_centro_symmetry_descriptors",
"get_diamond_structure_descriptors",
"get_steinhardt_parameters",
"get_voronoi_volumes",
"get_snap_descriptor_derivatives",
"get_snap_descriptor_names",
"get_snap_descriptors_per_atom",
"get_average_of_unique_labels",
"get_cluster_positions",
"get_delaunay_neighbors",
"get_interstitials",
"get_layers",
"get_mean_positions",
"get_voronoi_neighbors",
"get_voronoi_vertices",
"get_strain",
"get_ir_reciprocal_mesh",
"get_symmetry",
"symmetrize_vectors",
"group_points_by_symmetry",
"get_primitive_cell",
"get_spacegroup",
"get_symmetry_dataset",
"get_equivalent_points",
]
4 changes: 3 additions & 1 deletion structuretoolkit/analyse/dscribe.py
@@ -13,7 +13,7 @@ def soap_descriptor_per_atom(
rbf: str = "gto",
weighting: Optional[np.ndarray] = None,
average: str = "off",
compression: dict = {"mode": "off", "species_weighting": None},
compression: dict = None,
species: Optional[list] = None,
periodic: bool = True,
sparse: bool = False,
@@ -50,6 +50,8 @@
"""
from dscribe.descriptors import SOAP

if compression is None:
compression = {"mode": "off", "species_weighting": None}
if species is None:
species = list(set(structure.get_chemical_symbols()))
periodic_soap = SOAP(
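The change above swaps a mutable dict default for a None sentinel, the usual fix for flake8-bugbear's mutable-argument-default rule: a default is evaluated once at function-definition time, so a dict or list default is shared between calls. A minimal illustration (not repository code):

def append_shared(item, bucket=[]):   # one list, created when the function is defined
    bucket.append(item)
    return bucket

def append_fresh(item, bucket=None):  # sentinel: a new list per call unless one is passed
    if bucket is None:
        bucket = []
    bucket.append(item)
    return bucket

print(append_shared(1), append_shared(2))  # [1, 2] [1, 2] -- state leaks between calls
print(append_fresh(1), append_fresh(2))    # [1] [2]
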
74 changes: 37 additions & 37 deletions structuretoolkit/analyse/neighbors.py
@@ -1,10 +1,9 @@
# coding: utf-8
# Copyright (c) Max-Planck-Institut für Eisenforschung GmbH - Computational Materials Design (CM) Department
# Distributed under the terms of "New BSD License", see the LICENSE file.

import itertools
import warnings
from typing import Dict, List, Optional, Tuple, Union
from typing import Optional, Union

import numpy as np
from ase.atoms import Atoms
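Dropping Dict, List and Tuple from the typing import goes together with the annotation rewrites further down in this file: pyupgrade (UP) replaces the typing aliases with the built-in generics of PEP 585, which are valid in annotations on Python 3.9 and newer. A before/after sketch, assuming that minimum version:

# before: typing aliases
from typing import Dict, List

def sizes_old(ids: List[int]) -> Dict[int, List[int]]: ...

# after: built-in generics, no typing import needed for these
def sizes_new(ids: list[int]) -> dict[int, list[int]]: ...
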
@@ -104,11 +103,11 @@ def _set_mode(self, new_mode: str) -> None:
Raises:
KeyError: If the new mode is not found in the available modes.
"""
if new_mode not in self._mode.keys():
if new_mode not in self._mode:
raise KeyError(
f"{new_mode} not found. Available modes: {', '.join(self._mode.keys())}"
)
self._mode = {key: False for key in self._mode.keys()}
self._mode = {key: False for key in self._mode}
self._mode[new_mode] = True

def __repr__(self) -> str:
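The membership tests above drop the redundant .keys() call (flake8-simplify's dict-membership rule): testing `in` on a dict, or iterating over it, already goes through its keys. A toy example using the filling-mode names from this class (values illustrative):

modes = {"filled": True, "ragged": False, "flattened": False}

print("ragged" in modes)          # preferred form
print("ragged" in modes.keys())   # flagged; the extra keys() view adds nothing
reset = {key: False for key in modes}  # same simplification as in _set_mode above
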
@@ -366,7 +365,7 @@ def _get_distances_and_indices(
num_neighbors: Optional[int] = None,
cutoff_radius: float = np.inf,
width_buffer: float = 1.2,
) -> Tuple[np.ndarray, np.ndarray]:
) -> tuple[np.ndarray, np.ndarray]:
"""
Get the distances and indices of the neighbors for the given positions.
@@ -406,7 +405,8 @@ def _get_distances_and_indices(
warnings.warn(
"Number of neighbors found within the cutoff_radius is equal to (estimated) "
+ "num_neighbors. Increase num_neighbors (or set it to None) or "
+ "width_buffer to find all neighbors within cutoff_radius."
+ "width_buffer to find all neighbors within cutoff_radius.",
stacklevel=2,
)
self._extended_indices = indices.copy()
indices[distances < np.inf] = self._get_wrapped_indices()[
@@ -508,7 +508,8 @@ def _estimate_num_neighbors(
if num_neighbors > self.num_neighbors:
warnings.warn(
"Taking a larger search area after initialization has the risk of "
+ "missing neighborhood atoms"
+ "missing neighborhood atoms",
stacklevel=2,
)
return num_neighbors

@@ -632,15 +633,14 @@ def _check_width(self, width: float, pbc: list[bool, bool, bool]) -> bool:
bool: True if the width exceeds the specified value, False otherwise.
"""
if any(pbc) and np.prod(self.filled.distances.shape) > 0:
if (
np.linalg.norm(
self.flattened.vecs[..., pbc], axis=-1, ord=self.norm_order
).max()
> width
):
return True
return False
return bool(
any(pbc)
and np.prod(self.filled.distances.shape) > 0
and np.linalg.norm(
self.flattened.vecs[..., pbc], axis=-1, ord=self.norm_order
).max()
> width
)

def get_spherical_harmonics(
self,
@@ -811,9 +811,9 @@ def __getattr__(self, name):
def __dir__(self):
"""Show value names which are available for different filling modes."""
return list(
set(
["distances", "vecs", "indices", "shells", "atom_numbers"]
).intersection(self.ref_neigh.__dir__())
{"distances", "vecs", "indices", "shells", "atom_numbers"}.intersection(
self.ref_neigh.__dir__()
)
)


@@ -1008,7 +1008,7 @@ def get_global_shells(

def get_shell_matrix(
self,
chemical_pair: Optional[List[str]] = None,
chemical_pair: Optional[list[str]] = None,
cluster_by_distances: bool = False,
cluster_by_vecs: bool = False,
):
@@ -1225,7 +1225,7 @@ def reset_clusters(self, vecs: bool = True, distances: bool = True):

def cluster_analysis(
self, id_list: list, return_cluster_sizes: bool = False
) -> Union[Dict[int, List[int]], Tuple[Dict[int, List[int]], List[int]]]:
) -> Union[dict[int, list[int]], tuple[dict[int, list[int]], list[int]]]:
"""
Perform cluster analysis on a list of atom IDs.
@@ -1240,11 +1240,8 @@ def cluster_analysis(
"""
self._cluster = [0] * len(self._ref_structure)
c_count = 1
# element_list = self.get_atomic_numbers()
for ia in id_list:
# el0 = element_list[ia]
nbrs = self.ragged.indices[ia]
# print ("nbrs: ", ia, nbrs)
if self._cluster[ia] == 0:
self._cluster[ia] = c_count
self.__probe_cluster(c_count, nbrs, id_list)
@@ -1261,7 +1258,7 @@ def cluster_analysis(
return cluster_dict # sizes

def __probe_cluster(
self, c_count: int, neighbors: List[int], id_list: List[int]
self, c_count: int, neighbors: list[int], id_list: list[int]
) -> None:
"""
Recursively probe the cluster and assign cluster IDs to neighbors.
@@ -1275,19 +1272,20 @@ def __probe_cluster(
None
"""
for nbr_id in neighbors:
if self._cluster[nbr_id] == 0:
if nbr_id in id_list: # TODO: check also for ordered structures
self._cluster[nbr_id] = c_count
nbrs = self.ragged.indices[nbr_id]
self.__probe_cluster(c_count, nbrs, id_list)
if (
self._cluster[nbr_id] == 0 and nbr_id in id_list
): # TODO: check also for ordered structures
self._cluster[nbr_id] = c_count
nbrs = self.ragged.indices[nbr_id]
self.__probe_cluster(c_count, nbrs, id_list)

# TODO: combine with corresponding routine in plot3d
def get_bonds(
self,
radius: float = np.inf,
max_shells: Optional[int] = None,
prec: float = 0.1,
) -> List[Dict[str, List[List[int]]]]:
) -> list[dict[str, list[list[int]]]]:
"""
Get the bonds in the structure.
@@ -1303,7 +1301,7 @@ def get_bonds(

def get_cluster(
dist_vec: np.ndarray, ind_vec: np.ndarray, prec: float = prec
) -> List[np.ndarray]:
) -> list[np.ndarray]:
"""
Get clusters from a distance vector and index vector.
@@ -1326,7 +1324,6 @@ def get_cluster(
ind_shell = []
for d, i in zip(dist, ind):
id_list = get_cluster(d[d < radius], i[d < radius])
# print ("id: ", d[d<radius], id_list, dist_lst)
ia_shells_dict = {}
for i_shell_list in id_list:
ia_shell_dict = {}
Expand All @@ -1338,9 +1335,11 @@ def get_cluster(
for el, ia_lst in ia_shell_dict.items():
if el not in ia_shells_dict:
ia_shells_dict[el] = []
if max_shells is not None:
if len(ia_shells_dict[el]) + 1 > max_shells:
continue
if (
max_shells is not None
and len(ia_shells_dict[el]) + 1 > max_shells
):
continue
ia_shells_dict[el].append(ia_lst)
ind_shell.append(ia_shells_dict)
return ind_shell
@@ -1457,7 +1456,8 @@ def _get_neighbors(
if neigh._check_width(width=width, pbc=structure.pbc):
warnings.warn(
"width_buffer may have been too small - "
"most likely not all neighbors properly assigned"
"most likely not all neighbors properly assigned",
stacklevel=2,
)
return neigh

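Several warnings.warn calls in this file now pass stacklevel=2 (flake8-bugbear's explicit-stacklevel check), so the warning is attributed to the code that called the library function rather than to a line inside the library. A small sketch of the effect:

import warnings

def estimate(num_neighbors):
    # stacklevel=2 points the warning at whoever called estimate(),
    # which is usually where the questionable value comes from
    warnings.warn("num_neighbors may be too small", stacklevel=2)
    return num_neighbors

estimate(12)  # the reported file/line is this call site, not the warn() line above
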
(diffs for the remaining 14 changed files not shown)
