Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Get cell #186

Merged
merged 14 commits into from
May 21, 2024
1 change: 1 addition & 0 deletions structuretoolkit/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
get_wrapped_coordinates,
pymatgen_to_ase,
select_index,
get_cell,
)

# Visualize
Expand Down
1 change: 1 addition & 0 deletions structuretoolkit/common/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
get_vertical_length,
get_wrapped_coordinates,
select_index,
get_cell,
)
from structuretoolkit.common.pymatgen import (
ase_to_pymatgen,
Expand Down
35 changes: 34 additions & 1 deletion structuretoolkit/common/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from ase.atoms import Atoms
from ase.data import atomic_numbers
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Remove unused import to clean up the code.

- from ase.data import atomic_numbers

This import is unused and should be removed to clean up the codebase.


Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation.

Suggested change
from ase.data import atomic_numbers

from scipy.sparse import coo_matrix
from typing import Optional
from typing import Optional, Union


def get_extended_positions(
Expand Down Expand Up @@ -226,3 +226,36 @@ def apply_strain(
structure_copy.set_cell(cell, scale_atoms=True)
if return_box:
return structure_copy


def get_cell(cell: Union[Atoms, list, tuple, np.ndarray, float]):
"""
Get cell of an ase structure, or convert a float or a (3,)-array into a
orthogonal cell.
Args:
cell (Atoms|ndarray|list|float|tuple): Cell
Returns:
(3, 3)-array: Cell
"""
if isinstance(cell, Atoms):
return cell.cell
# Convert float into (3,)-array. No effect if it is (3,3)-array or
# (3,)-array. Raises error if the shape is not correct
try:
cell = cell * np.ones(3)
except ValueError:
raise ValueError(
f"Invalid cell type or shape: {type(cell).__name__}, {np.shape(cell)}"
)

if np.shape(cell) == (3, 3):
return cell
# Convert (3,)-array into (3,3)-array. Raises error if the shape is wrong
try:
return cell * np.eye(3)
except ValueError:
raise ValueError(
f"Invalid cell type or shape: {type(cell).__name__}, {np.shape(cell)}"
)
Comment on lines +231 to +261
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Refactor get_cell for clarity and efficiency.

247-        cell = cell * np.ones(3)
+        cell = np.full(3, cell)

Consider using np.full instead of multiplying by np.ones(3) for clarity and efficiency. This change was suggested in a previous review and is still applicable.


Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation.

Suggested change
def get_cell(cell: Union[Atoms, list, tuple, np.ndarray, float]):
"""
Get cell of an ase structure, or convert a float or a (3,)-array into a
orthogonal cell.
Args:
cell (Atoms|ndarray|list|float|tuple): Cell
Returns:
(3, 3)-array: Cell
"""
if isinstance(cell, Atoms):
return cell.cell
# Convert float into (3,)-array. No effect if it is (3,3)-array or
# (3,)-array. Raises error if the shape is not correct
try:
cell = cell * np.ones(3)
except ValueError:
raise ValueError(
f"Invalid cell type or shape: {type(cell).__name__}, {np.shape(cell)}"
)
if np.shape(cell) == (3, 3):
return cell
# Convert (3,)-array into (3,3)-array. Raises error if the shape is wrong
try:
return cell * np.eye(3)
except ValueError:
raise ValueError(
f"Invalid cell type or shape: {type(cell).__name__}, {np.shape(cell)}"
)
try:
cell = np.full(3, cell)

4 changes: 3 additions & 1 deletion structuretoolkit/visualize.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
from typing import Optional
from scipy.interpolate import interp1d

from structuretoolkit.common.helper import get_cell

__author__ = "Joerg Neugebauer, Sudarsan Surendralal"
__copyright__ = (
"Copyright 2021, Max-Planck-Institut für Eisenforschung GmbH - "
Expand Down Expand Up @@ -166,7 +168,7 @@ def _get_box_skeleton(cell: np.ndarray):


def _draw_box_plotly(fig, structure, px, go):
cell = structure.cell
cell = get_cell(structure)
data = fig.data
for lines in _get_box_skeleton(cell):
fig = px.line_3d(**{xx: vv for xx, vv in zip(["x", "y", "z"], lines.T)})
Expand Down
28 changes: 28 additions & 0 deletions tests/test_helpers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
# coding: utf-8
# Copyright (c) Max-Planck-Institut für Eisenforschung GmbH - Computational Materials Design (CM) Department
# Distributed under the terms of "New BSD License", see the LICENSE file.

import unittest
import numpy as np
from ase.build import bulk
import structuretoolkit as stk


class TestHelpers(unittest.TestCase):
def test_get_cell(self):
self.assertEqual((3 * np.eye(3)).tolist(), stk.get_cell(3).tolist())
self.assertEqual(
([1, 2, 3] * np.eye(3)).tolist(), stk.get_cell([1, 2, 3]).tolist()
)
atoms = bulk("Fe")
self.assertEqual(
atoms.cell.tolist(), stk.get_cell(atoms).tolist()
)
with self.assertRaises(ValueError):
stk.get_cell(np.arange(4))
with self.assertRaises(ValueError):
stk.get_cell(np.ones((4, 3)))


if __name__ == "__main__":
unittest.main()
Loading