Skip to content

Commit

Permalink
Merge pull request #186 from pyiron/get_cell
Browse files Browse the repository at this point in the history
Get cell
  • Loading branch information
samwaseda authored May 21, 2024
2 parents 631eda8 + 0d3a426 commit 5a860fc
Show file tree
Hide file tree
Showing 5 changed files with 67 additions and 2 deletions.
1 change: 1 addition & 0 deletions structuretoolkit/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
get_wrapped_coordinates,
pymatgen_to_ase,
select_index,
get_cell,
)

# Visualize
Expand Down
1 change: 1 addition & 0 deletions structuretoolkit/common/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
get_vertical_length,
get_wrapped_coordinates,
select_index,
get_cell,
)
from structuretoolkit.common.pymatgen import (
ase_to_pymatgen,
Expand Down
35 changes: 34 additions & 1 deletion structuretoolkit/common/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from ase.atoms import Atoms
from ase.data import atomic_numbers
from scipy.sparse import coo_matrix
from typing import Optional
from typing import Optional, Union


def get_extended_positions(
Expand Down Expand Up @@ -226,3 +226,36 @@ def apply_strain(
structure_copy.set_cell(cell, scale_atoms=True)
if return_box:
return structure_copy


def get_cell(cell: Union[Atoms, list, tuple, np.ndarray, float]):
"""
Get cell of an ase structure, or convert a float or a (3,)-array into a
orthogonal cell.
Args:
cell (Atoms|ndarray|list|float|tuple): Cell
Returns:
(3, 3)-array: Cell
"""
if isinstance(cell, Atoms):
return cell.cell
# Convert float into (3,)-array. No effect if it is (3,3)-array or
# (3,)-array. Raises error if the shape is not correct
try:
cell = cell * np.ones(3)
except ValueError:
raise ValueError(
f"Invalid cell type or shape: {type(cell).__name__}, {np.shape(cell)}"
)

if np.shape(cell) == (3, 3):
return cell
# Convert (3,)-array into (3,3)-array. Raises error if the shape is wrong
try:
return cell * np.eye(3)
except ValueError:
raise ValueError(
f"Invalid cell type or shape: {type(cell).__name__}, {np.shape(cell)}"
)
4 changes: 3 additions & 1 deletion structuretoolkit/visualize.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
from typing import Optional
from scipy.interpolate import interp1d

from structuretoolkit.common.helper import get_cell

__author__ = "Joerg Neugebauer, Sudarsan Surendralal"
__copyright__ = (
"Copyright 2021, Max-Planck-Institut für Eisenforschung GmbH - "
Expand Down Expand Up @@ -166,7 +168,7 @@ def _get_box_skeleton(cell: np.ndarray):


def _draw_box_plotly(fig, structure, px, go):
cell = structure.cell
cell = get_cell(structure)
data = fig.data
for lines in _get_box_skeleton(cell):
fig = px.line_3d(**{xx: vv for xx, vv in zip(["x", "y", "z"], lines.T)})
Expand Down
28 changes: 28 additions & 0 deletions tests/test_helpers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
# coding: utf-8
# Copyright (c) Max-Planck-Institut für Eisenforschung GmbH - Computational Materials Design (CM) Department
# Distributed under the terms of "New BSD License", see the LICENSE file.

import unittest
import numpy as np
from ase.build import bulk
import structuretoolkit as stk


class TestHelpers(unittest.TestCase):
def test_get_cell(self):
self.assertEqual((3 * np.eye(3)).tolist(), stk.get_cell(3).tolist())
self.assertEqual(
([1, 2, 3] * np.eye(3)).tolist(), stk.get_cell([1, 2, 3]).tolist()
)
atoms = bulk("Fe")
self.assertEqual(
atoms.cell.tolist(), stk.get_cell(atoms).tolist()
)
with self.assertRaises(ValueError):
stk.get_cell(np.arange(4))
with self.assertRaises(ValueError):
stk.get_cell(np.ones((4, 3)))


if __name__ == "__main__":
unittest.main()

0 comments on commit 5a860fc

Please sign in to comment.