Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Get cell #186

Merged
merged 14 commits into from
May 21, 2024
1 change: 1 addition & 0 deletions structuretoolkit/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
get_wrapped_coordinates,
pymatgen_to_ase,
select_index,
get_cell,
)

# Visualize
Expand Down
1 change: 1 addition & 0 deletions structuretoolkit/common/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
get_vertical_length,
get_wrapped_coordinates,
select_index,
get_cell,
)
from structuretoolkit.common.pymatgen import (
ase_to_pymatgen,
Expand Down
31 changes: 30 additions & 1 deletion structuretoolkit/common/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from ase.atoms import Atoms
from ase.data import atomic_numbers
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Remove unused import to clean up the code.

- from ase.data import atomic_numbers

This import is unused and should be removed to clean up the codebase.


Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation.

Suggested change
from ase.data import atomic_numbers

from scipy.sparse import coo_matrix
from typing import Optional
from typing import Optional, Union


def get_extended_positions(
Expand Down Expand Up @@ -226,3 +226,32 @@ def apply_strain(
structure_copy.set_cell(cell, scale_atoms=True)
if return_box:
return structure_copy


def get_cell(cell: Union[Atoms, list, np.ndarray, float]):
samwaseda marked this conversation as resolved.
Show resolved Hide resolved
"""
Get cell of an ase structure, or convert a float or a (3,)-array into a
orthogonal cell.

Args:
cell (Atoms|ndarray|list|float|tuple): Cell

Returns:
(3, 3)-array: Cell
"""
if isinstance(cell, Atoms):
return cell.cell
# Convert float into (3,)-array. No effect if it is (3,3)-array or
# (3,)-array. Raises error if the shape is not correct
try:
cell = cell * np.ones(3)
except ValueError:
raise ValueError(
"cell must be a float, (3,)-ndarray/list/tuple or"
" (3,3)-ndarray/list/tuple"
)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This guy gets both duplicated and doesn't have a friendly custom class to test against; you might be able to kill both birds with one stone following a pattern like:

class CustomError(ValueError):
    def __init__(self, /, msg="My Custom message"):
        super().__init__(msg)

And then raising that in both places

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I now gave it a short try, but after I implemented what Code Rabbit suggested, it looked complicated to pass arguments there, and since in terms of the number of lines it doesn't change much, I kept Code Rabbit's suggestion


if np.shape(cell) == (3, 3):
return cell
# Convert (3,)-array into (3,3)-array. Raises error if the shape is wrong
return cell * np.eye(3)
samwaseda marked this conversation as resolved.
Show resolved Hide resolved
26 changes: 26 additions & 0 deletions tests/test_helpers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
# coding: utf-8
# Copyright (c) Max-Planck-Institut für Eisenforschung GmbH - Computational Materials Design (CM) Department
# Distributed under the terms of "New BSD License", see the LICENSE file.

import unittest
import numpy as np
from ase.build import bulk
import structuretoolkit as stk


class TestHelpers(unittest.TestCase):
def test_get_cell(self):
self.assertEqual((3 * np.eye(3)).tolist(), stk.get_cell(3).tolist())
self.assertEqual(
([1, 2, 3] * np.eye(3)).tolist(), stk.get_cell([1, 2, 3]).tolist()
)
atoms = bulk("Fe")
self.assertEqual(
atoms.cell.tolist(), stk.get_cell(atoms).tolist()
)
with self.assertRaises(ValueError):
stk.get_cell(np.arange(4))


if __name__ == "__main__":
unittest.main()
Loading