Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Get cell #186

Merged
merged 14 commits into from
May 21, 2024
1 change: 1 addition & 0 deletions structuretoolkit/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
get_wrapped_coordinates,
pymatgen_to_ase,
select_index,
get_cell,
)

# Visualize
Expand Down
1 change: 1 addition & 0 deletions structuretoolkit/common/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
get_vertical_length,
get_wrapped_coordinates,
select_index,
get_cell,
)
from structuretoolkit.common.pymatgen import (
ase_to_pymatgen,
Expand Down
37 changes: 36 additions & 1 deletion structuretoolkit/common/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from ase.atoms import Atoms
from ase.data import atomic_numbers
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Remove unused import to clean up the code.

- from ase.data import atomic_numbers

This import is unused and should be removed to clean up the codebase.


Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation.

Suggested change
from ase.data import atomic_numbers

from scipy.sparse import coo_matrix
from typing import Optional
from typing import Optional, Union


def get_extended_positions(
Expand Down Expand Up @@ -226,3 +226,38 @@ def apply_strain(
structure_copy.set_cell(cell, scale_atoms=True)
if return_box:
return structure_copy


def get_cell(cell: Union[Atoms, list, np.ndarray, float]):
samwaseda marked this conversation as resolved.
Show resolved Hide resolved
"""
Get cell of an ase structure, or convert a float or a (3,)-array into a
orthogonal cell.

Args:
cell (Atoms|ndarray|list|float|tuple): Cell

Returns:
(3, 3)-array: Cell
"""
if isinstance(cell, Atoms):
return cell.cell
# Convert float into (3,)-array. No effect if it is (3,3)-array or
# (3,)-array. Raises error if the shape is not correct
try:
cell = cell * np.ones(3)
except ValueError:
raise ValueError(
"cell must be a float, (3,)-ndarray/list/tuple or"
" (3,3)-ndarray/list/tuple"
)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This guy gets both duplicated and doesn't have a friendly custom class to test against; you might be able to kill both birds with one stone following a pattern like:

class CustomError(ValueError):
    def __init__(self, /, msg="My Custom message"):
        super().__init__(msg)

And then raising that in both places

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I now gave it a short try, but after I implemented what Code Rabbit suggested, it looked complicated to pass arguments there, and since in terms of the number of lines it doesn't change much, I kept Code Rabbit's suggestion


if np.shape(cell) == (3, 3):
return cell
# Convert (3,)-array into (3,3)-array. Raises error if the shape is wrong
try:
return cell * np.eye(3)
except ValueError:
raise ValueError(
"cell must be a float, (3,)-ndarray/list/tuple or"
" (3,3)-ndarray/list/tuple"
)
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Refactor get_cell for enhanced clarity and error handling.

231 def get_cell(cell: Union[Atoms, list, np.ndarray, float]):
232     """
233     Get cell of an ase structure, or convert a float or a (3,)-array into an
234     orthogonal cell.
235 
236     Args:
237         cell (Atoms|ndarray|list|float|tuple): Cell
238 
239     Returns:
240         (3, 3)-array: Cell
241     """
242     if isinstance(cell, Atoms):
243         return cell.cell
244     # Convert float into (3,)-array. No effect if it is (3,3)-array or
245     # (3,)-array. Raises error if the shape is not correct
246     try:
247         cell = cell * np.ones(3)
248     except ValueError:
249         raise ValueError(
250             "cell must be a float, (3,)-ndarray/list/tuple or"
251             " (3,3)-ndarray/list/tuple"
252         )
253 
254     if np.shape(cell) == (3, 3):
255         return cell
256     # Convert (3,)-array into (3,3)-array. Raises error if the shape is wrong
257     try:
258         return cell * np.eye(3)
259     except ValueError:
260         raise ValueError(
261             "cell must be a float, (3,)-ndarray/list/tuple or"
262             " (3,3)-ndarray/list/tuple"
263         )

Consider enhancing the error messages to include the actual received input type or shape for better debugging.


Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation.

Suggested change
def get_cell(cell: Union[Atoms, list, np.ndarray, float]):
"""
Get cell of an ase structure, or convert a float or a (3,)-array into a
orthogonal cell.
Args:
cell (Atoms|ndarray|list|float|tuple): Cell
Returns:
(3, 3)-array: Cell
"""
if isinstance(cell, Atoms):
return cell.cell
# Convert float into (3,)-array. No effect if it is (3,3)-array or
# (3,)-array. Raises error if the shape is not correct
try:
cell = cell * np.ones(3)
except ValueError:
raise ValueError(
"cell must be a float, (3,)-ndarray/list/tuple or"
" (3,3)-ndarray/list/tuple"
)
if np.shape(cell) == (3, 3):
return cell
# Convert (3,)-array into (3,3)-array. Raises error if the shape is wrong
try:
return cell * np.eye(3)
except ValueError:
raise ValueError(
"cell must be a float, (3,)-ndarray/list/tuple or"
" (3,3)-ndarray/list/tuple"
)
def get_cell(cell: Union[Atoms, list, np.ndarray, float]):
"""
Get cell of an ase structure, or convert a float or a (3,)-array into an
orthogonal cell.
Args:
cell (Atoms|ndarray|list|float|tuple): Cell
Returns:
(3, 3)-array: Cell
"""
if isinstance(cell, Atoms):
return cell.cell
# Convert float into (3,)-array. No effect if it is (3,3)-array or
# (3,)-array. Raises error if the shape is not correct
try:
cell = cell * np.ones(3)
except ValueError:
raise ValueError(
f"cell must be a float, (3,)-ndarray/list/tuple or"
f" (3,3)-ndarray/list/tuple, but got {type(cell)} with shape {np.shape(cell)}"
)
if np.shape(cell) == (3, 3):
return cell
# Convert (3,)-array into (3,3)-array. Raises error if the shape is wrong
try:
return cell * np.eye(3)
except ValueError:
raise ValueError(
f"cell must be a float, (3,)-ndarray/list/tuple or"
f" (3,3)-ndarray/list/tuple, but got {type(cell)} with shape {np.shape(cell)}"
)

4 changes: 3 additions & 1 deletion structuretoolkit/visualize.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
from typing import Optional
from scipy.interpolate import interp1d

from structuretoolkit.common.helper import get_cell

__author__ = "Joerg Neugebauer, Sudarsan Surendralal"
__copyright__ = (
"Copyright 2021, Max-Planck-Institut für Eisenforschung GmbH - "
Expand Down Expand Up @@ -166,7 +168,7 @@ def _get_box_skeleton(cell: np.ndarray):


def _draw_box_plotly(fig, structure, px, go):
cell = structure.cell
cell = get_cell(structure)
data = fig.data
for lines in _get_box_skeleton(cell):
fig = px.line_3d(**{xx: vv for xx, vv in zip(["x", "y", "z"], lines.T)})
Expand Down
28 changes: 28 additions & 0 deletions tests/test_helpers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
# coding: utf-8
# Copyright (c) Max-Planck-Institut für Eisenforschung GmbH - Computational Materials Design (CM) Department
# Distributed under the terms of "New BSD License", see the LICENSE file.

import unittest
import numpy as np
from ase.build import bulk
import structuretoolkit as stk


class TestHelpers(unittest.TestCase):
def test_get_cell(self):
self.assertEqual((3 * np.eye(3)).tolist(), stk.get_cell(3).tolist())
self.assertEqual(
([1, 2, 3] * np.eye(3)).tolist(), stk.get_cell([1, 2, 3]).tolist()
)
atoms = bulk("Fe")
self.assertEqual(
atoms.cell.tolist(), stk.get_cell(atoms).tolist()
)
with self.assertRaises(ValueError):
stk.get_cell(np.arange(4))
with self.assertRaises(ValueError):
stk.get_cell(np.ones((4, 3)))


if __name__ == "__main__":
unittest.main()
Loading