diff --git a/elastica/__init__.py b/elastica/__init__.py index a088c7230..bd25b92cc 100644 --- a/elastica/__init__.py +++ b/elastica/__init__.py @@ -42,6 +42,9 @@ HingeJoint, SelfContact, ) +from elastica.contact_forces import ( + NoContact, +) from elastica.callback_functions import CallBackBaseClass, ExportCallBack, MyCallBack from elastica.dissipation import ( DamperBase, @@ -54,6 +57,8 @@ from elastica.modules.constraints import Constraints from elastica.modules.forcing import Forcing from elastica.modules.damping import Damping +from elastica.modules.contact import Contact + from elastica.transformations import inv_skew_symmetrize from elastica.transformations import rotate from elastica._calculus import ( @@ -66,7 +71,7 @@ ) from elastica._linalg import levi_civita_tensor from elastica.utils import isqrt -from elastica.typing import RodType, SystemType +from elastica.typing import RodType, SystemType, AllowedContactType from elastica.timestepper import ( integrate, PositionVerlet, diff --git a/elastica/contact_forces.py b/elastica/contact_forces.py new file mode 100644 index 000000000..f9b9afa0d --- /dev/null +++ b/elastica/contact_forces.py @@ -0,0 +1,68 @@ +__doc__ = """ Numba implementation module containing contact between rods and rigid bodies and other rods rigid bodies or surfaces.""" + +from elastica.typing import SystemType, AllowedContactType +from elastica.rod import RodBase + + +class NoContact: + """ + This is the base class for contact applied between rod-like objects and allowed contact objects. + + Notes + ----- + Every new contact class must be derived + from NoContact class. + + """ + + def __init__(self): + """ + NoContact class does not need any input parameters. + """ + + def _check_order( + self, + system_one: SystemType, + system_two: AllowedContactType, + ) -> None: + """ + This checks the contact order between a SystemType object and an AllowedContactType object, the order should follow: Rod, Rigid body, Surface. + In NoContact class, this just checks if system_two is a rod then system_one must be a rod. + + + Parameters + ---------- + system_one + SystemType + system_two + AllowedContactType + """ + if issubclass(system_two.__class__, RodBase): + if not issubclass(system_one.__class__, RodBase): + raise TypeError( + "Systems provided to the contact class have incorrect order. \n" + " First system is {0} and second system is {1} . \n" + " If the first system is a rod, the second system can be a rod, rigid body or surface. \n" + " If the first system is a rigid body, the second system can be a rigid body or surface.".format( + system_one.__class__, system_two.__class__ + ) + ) + + def apply_contact( + self, + system_one: SystemType, + system_two: AllowedContactType, + ) -> None: + """ + Apply contact forces and torques between SystemType object and AllowedContactType object. + + In NoContact class, this routine simply passes. + + Parameters + ---------- + system_one : SystemType + Rod or rigid-body object + system_two : AllowedContactType + Rod, rigid-body, or surface object + """ + pass diff --git a/elastica/modules/__init__.py b/elastica/modules/__init__.py index 889bf69fd..a801adb73 100644 --- a/elastica/modules/__init__.py +++ b/elastica/modules/__init__.py @@ -10,3 +10,4 @@ from .forcing import Forcing from .callbacks import CallBacks from .damping import Damping +from .contact import Contact diff --git a/elastica/modules/base_system.py b/elastica/modules/base_system.py index bc26f0d83..c3c6de084 100644 --- a/elastica/modules/base_system.py +++ b/elastica/modules/base_system.py @@ -11,6 +11,7 @@ from elastica.rod import RodBase from elastica.rigidbody import RigidBodyBase +from elastica.surface import SurfaceBase from elastica.modules.memory_block import construct_memory_block_structures from elastica._synchronize_periodic_boundary import _ConstrainPeriodicBoundaries @@ -54,7 +55,7 @@ def __init__(self): # We need to initialize our mixin classes super(BaseSystemCollection, self).__init__() # List of system types/bases that are allowed - self.allowed_sys_types = (RodBase, RigidBodyBase) + self.allowed_sys_types = (RodBase, RigidBodyBase, SurfaceBase) # List of systems to be integrated self._systems = [] # Flag Finalize: Finalizing twice will cause an error, diff --git a/elastica/modules/contact.py b/elastica/modules/contact.py new file mode 100644 index 000000000..a5aada9f7 --- /dev/null +++ b/elastica/modules/contact.py @@ -0,0 +1,175 @@ +__doc__ = """ +Contact +------- + +Provides the contact interface to apply contact forces between objects +(rods, rigid bodies, surfaces). +""" + +from elastica.typing import SystemType, AllowedContactType + + +class Contact: + """ + The Contact class is a module for applying contact between rod-like objects . To apply contact between rod-like objects, + the simulator class must be derived from the Contact class. + + Attributes + ---------- + _contacts: list + List of contact classes defined for rod-like objects. + """ + + def __init__(self): + self._contacts = [] + super(Contact, self).__init__() + self._feature_group_synchronize.append(self._call_contacts) + self._feature_group_finalize.append(self._finalize_contact) + + def detect_contact_between( + self, first_system: SystemType, second_system: AllowedContactType + ): + """ + This method adds contact detection between two objects using the selected contact class. + You need to input the two objects that are to be connected. + + Parameters + ---------- + first_system : SystemType + Rod or rigid body object + second_system : AllowedContactType + Rod, rigid body or surface object + + Returns + ------- + + """ + sys_idx = [None] * 2 + for i_sys, sys in enumerate((first_system, second_system)): + sys_idx[i_sys] = self._get_sys_idx_if_valid(sys) + + # Create _Contact object, cache it and return to user + _contact = _Contact(*sys_idx) + self._contacts.append(_contact) + + return _contact + + def _finalize_contact(self) -> None: + + # dev : the first indices stores the + # (first_rod_idx, second_rod_idx) + # to apply the contacts to + # Technically we can use another array but it its one more book-keeping + # step. Being lazy, I put them both in the same array + self._contacts[:] = [(*contact.id(), contact()) for contact in self._contacts] + + # check contact order + for ( + first_sys_idx, + second_sys_idx, + contact, + ) in self._contacts: + contact._check_order( + self._systems[first_sys_idx], + self._systems[second_sys_idx], + ) + + def _call_contacts(self, *args, **kwargs): + for ( + first_sys_idx, + second_sys_idx, + contact, + ) in self._contacts: + contact.apply_contact( + self._systems[first_sys_idx], + self._systems[second_sys_idx], + *args, + **kwargs, + ) + + +class _Contact: + """ + Contact module private class + + Attributes + ---------- + _first_sys_idx: int + _second_sys_idx: int + _contact_cls: list + *args + Variable length argument list. + **kwargs + Arbitrary keyword arguments. + """ + + def __init__( + self, + first_sys_idx: int, + second_sys_idx: int, + ) -> None: + """ + + Parameters + ---------- + first_sys_idx + second_sys_idx + """ + self.first_sys_idx = first_sys_idx + self.second_sys_idx = second_sys_idx + self._contact_cls = None + self._args = () + self._kwargs = {} + + def using(self, contact_cls: object, *args, **kwargs): + """ + This method is a module to set which contact class is used to apply contact + between user defined rod-like objects. + + Parameters + ---------- + contact_cls: object + User defined contact class. + *args + Variable length argument list + **kwargs + Arbitrary keyword arguments. + + Returns + ------- + + """ + from elastica.contact_forces import NoContact + + assert issubclass( + contact_cls, NoContact + ), "{} is not a valid contact class. Did you forget to derive from NoContact?".format( + contact_cls + ) + self._contact_cls = contact_cls + self._args = args + self._kwargs = kwargs + return self + + def id(self): + return ( + self.first_sys_idx, + self.second_sys_idx, + ) + + def __call__(self, *args, **kwargs): + if not self._contact_cls: + raise RuntimeError( + "No contacts provided to to establish contact between rod-like object id {0}" + " and {1}, but a Contact" + "was intended as per code. Did you forget to" + "call the `using` method?".format(*self.id()) + ) + + try: + return self._contact_cls(*self._args, **self._kwargs) + except (TypeError, IndexError): + raise TypeError( + r"Unable to construct contact class.\n" + r"Did you provide all necessary contact properties?" + ) diff --git a/elastica/surface/__init__.py b/elastica/surface/__init__.py new file mode 100644 index 000000000..ce755872e --- /dev/null +++ b/elastica/surface/__init__.py @@ -0,0 +1,2 @@ +__doc__ = """Surface classes""" +from elastica.surface.surface_base import SurfaceBase diff --git a/elastica/surface/surface_base.py b/elastica/surface/surface_base.py new file mode 100644 index 000000000..8d60d2f06 --- /dev/null +++ b/elastica/surface/surface_base.py @@ -0,0 +1,18 @@ +__doc__ = """Base class for surfaces""" + + +class SurfaceBase: + """ + Base class for all surfaces. + + Notes + ----- + All new surface classes must be derived from this SurfaceBase class. + + """ + + def __init__(self): + """ + SurfaceBase does not take any arguments. + """ + pass diff --git a/elastica/typing.py b/elastica/typing.py index b4382648f..d70a16433 100644 --- a/elastica/typing.py +++ b/elastica/typing.py @@ -1,7 +1,9 @@ from elastica.rod import RodBase from elastica.rigidbody import RigidBodyBase +from elastica.surface import SurfaceBase from typing import Type, Union RodType = Type[RodBase] SystemType = Union[RodType, Type[RigidBodyBase]] +AllowedContactType = Union[SystemType, Type[SurfaceBase]] diff --git a/poetry.lock b/poetry.lock index 427aa616b..f6b90c662 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1554,14 +1554,14 @@ files = [ [[package]] name = "requests" -version = "2.30.0" +version = "2.31.0" description = "Python HTTP for Humans." category = "main" optional = false python-versions = ">=3.7" files = [ - {file = "requests-2.30.0-py3-none-any.whl", hash = "sha256:10e94cc4f3121ee6da529d358cdaeaff2f1c409cd377dbc72b825852f2f7e294"}, - {file = "requests-2.30.0.tar.gz", hash = "sha256:239d7d4458afcb28a692cdd298d87542235f4ca8d36d03a15bfc128a6559a2f4"}, + {file = "requests-2.31.0-py3-none-any.whl", hash = "sha256:58cd2187c01e70e6e26505bca751777aa9f2ee0b7f4300988b709f44e013003f"}, + {file = "requests-2.31.0.tar.gz", hash = "sha256:942c5a758f98d790eaed1a29cb6eefc7ffb0d1cf7af05c3d2791656dbd6ad1e1"}, ] [package.dependencies] diff --git a/tests/test_modules/test_base_system.py b/tests/test_modules/test_base_system.py index 094f2a3ae..ebb3b7c3e 100644 --- a/tests/test_modules/test_base_system.py +++ b/tests/test_modules/test_base_system.py @@ -74,8 +74,16 @@ def test_extend_allowed_types(self, load_collection): from elastica.rod import RodBase from elastica.rigidbody import RigidBodyBase - - assert bsc.allowed_sys_types == (RodBase, RigidBodyBase, int, float, str) + from elastica.surface import SurfaceBase + + assert bsc.allowed_sys_types == ( + RodBase, + RigidBodyBase, + SurfaceBase, + int, + float, + str, + ) def test_extend_correctness(self, load_collection): """ diff --git a/tests/test_modules/test_connections.py b/tests/test_modules/test_connections.py index 59292e740..37e59d22b 100644 --- a/tests/test_modules/test_connections.py +++ b/tests/test_modules/test_connections.py @@ -4,6 +4,8 @@ from elastica.modules import Connections from elastica.modules.connections import _Connect +from numpy.testing import assert_allclose +from elastica.utils import Tolerance class TestConnect: @@ -172,13 +174,14 @@ def mock_init(self, *args, **kwargs): # Actual test is here, this should not throw with pytest.raises(TypeError) as excinfo: _ = connect() - assert "Unable to construct" in str(excinfo.value) + assert r"Unable to construct connection class.\n" + r"Did you provide all necessary joint properties?" == str(excinfo.value) class TestConnectionsMixin: from elastica.modules import BaseSystemCollection - class SystemCollectionWithConnectionsMixedin(BaseSystemCollection, Connections): + class SystemCollectionWithConnectionsMixin(BaseSystemCollection, Connections): pass # TODO fix link after new PR @@ -188,14 +191,10 @@ class MockRod(RodBase): def __init__(self, *args, **kwargs): self.n_elems = 3 # arbitrary number - # Connections assume that this promise is met - def __len__(self): - return 2 # a random number - @pytest.fixture(scope="function", params=[2, 10]) def load_system_with_connects(self, request): n_sys = request.param - sys_coll_with_connects = self.SystemCollectionWithConnectionsMixedin() + sys_coll_with_connects = self.SystemCollectionWithConnectionsMixin() for i_sys in range(n_sys): sys_coll_with_connects.append(self.MockRod(2, 3, 4, 5)) return sys_coll_with_connects @@ -224,51 +223,51 @@ def load_system_with_connects(self, request): def test_connect_with_illegal_index_throws( self, load_system_with_connects, sys_idx ): - scwc = load_system_with_connects + system_collection_with_connections = load_system_with_connects with pytest.raises(AssertionError) as excinfo: - scwc.connect(*sys_idx) + system_collection_with_connections.connect(*sys_idx) assert "exceeds number of" in str(excinfo.value) with pytest.raises(AssertionError) as excinfo: - scwc.connect(*[np.int_(x) for x in sys_idx]) + system_collection_with_connections.connect(*[np.int_(x) for x in sys_idx]) assert "exceeds number of" in str(excinfo.value) def test_connect_with_unregistered_system_throws(self, load_system_with_connects): - scwc = load_system_with_connects + system_collection_with_connections = load_system_with_connects # Register this rod mock_rod_registered = self.MockRod(5, 5, 5, 5) - scwc.append(mock_rod_registered) + system_collection_with_connections.append(mock_rod_registered) # Don't register this rod mock_rod = self.MockRod(2, 3, 4, 5) with pytest.raises(ValueError) as excinfo: - scwc.connect(mock_rod, mock_rod_registered) + system_collection_with_connections.connect(mock_rod, mock_rod_registered) assert "was not found, did you" in str(excinfo.value) # Switch arguments with pytest.raises(ValueError) as excinfo: - scwc.connect(mock_rod_registered, mock_rod) + system_collection_with_connections.connect(mock_rod_registered, mock_rod) assert "was not found, did you" in str(excinfo.value) def test_connect_with_illegal_system_throws(self, load_system_with_connects): - scwc = load_system_with_connects + system_collection_with_connections = load_system_with_connects # Register this rod mock_rod_registered = self.MockRod(5, 5, 5, 5) - scwc.append(mock_rod_registered) + system_collection_with_connections.append(mock_rod_registered) # Not a rod, but a list! mock_rod = [1, 2, 3, 5] with pytest.raises(TypeError) as excinfo: - scwc.connect(mock_rod, mock_rod_registered) + system_collection_with_connections.connect(mock_rod, mock_rod_registered) assert "not a sys" in str(excinfo.value) # Switch arguments with pytest.raises(TypeError) as excinfo: - scwc.connect(mock_rod_registered, mock_rod) + system_collection_with_connections.connect(mock_rod_registered, mock_rod) assert "not a sys" in str(excinfo.value) """ @@ -276,30 +275,32 @@ def test_connect_with_illegal_system_throws(self, load_system_with_connects): """ def test_connect_registers_and_returns_Connect(self, load_system_with_connects): - scwc = load_system_with_connects + system_collection_with_connections = load_system_with_connects mock_rod_one = self.MockRod(2, 3, 4, 5) - scwc.append(mock_rod_one) + system_collection_with_connections.append(mock_rod_one) mock_rod_two = self.MockRod(4, 5) - scwc.append(mock_rod_two) + system_collection_with_connections.append(mock_rod_two) - _mock_connect = scwc.connect(mock_rod_one, mock_rod_two) - assert _mock_connect in scwc._connections + _mock_connect = system_collection_with_connections.connect( + mock_rod_one, mock_rod_two + ) + assert _mock_connect in system_collection_with_connections._connections assert _mock_connect.__class__ == _Connect # check sane defaults provided for connection indices - assert None in _mock_connect.id() and None in _mock_connect.id() + assert _mock_connect.id()[2] is None and _mock_connect.id()[3] is None from elastica.joint import FreeJoint @pytest.fixture def load_rod_with_connects(self, load_system_with_connects): - scwc = load_system_with_connects + system_collection_with_connections = load_system_with_connects mock_rod_one = self.MockRod(2, 3, 4, 5) - scwc.append(mock_rod_one) + system_collection_with_connections.append(mock_rod_one) mock_rod_two = self.MockRod(5.0, 5.0) - scwc.append(mock_rod_two) + system_collection_with_connections.append(mock_rod_two) def mock_init(self, *args, **kwargs): pass @@ -310,28 +311,128 @@ def mock_init(self, *args, **kwargs): ) # Constrain any and all systems - scwc.connect(0, 1).using(MockConnect, 2, 42) # index based connect - scwc.connect(mock_rod_one, mock_rod_two).using( + system_collection_with_connections.connect(0, 1).using( + MockConnect, 2, 42 + ) # index based connect + system_collection_with_connections.connect(mock_rod_one, mock_rod_two).using( MockConnect, 2, 3 ) # system based connect - scwc.connect(0, mock_rod_one).using( + system_collection_with_connections.connect(0, mock_rod_one).using( MockConnect, 1, 2 ) # index/system based connect - return scwc, MockConnect + return system_collection_with_connections, MockConnect def test_connect_finalize_correctness(self, load_rod_with_connects): - scwc, connect_cls = load_rod_with_connects + system_collection_with_connections, connect_cls = load_rod_with_connects - scwc._finalize_connections() + system_collection_with_connections._finalize_connections() - for (fidx, sidx, fconnect, sconnect, connect) in scwc._connections: + for ( + fidx, + sidx, + fconnect, + sconnect, + connect, + ) in system_collection_with_connections._connections: assert type(fidx) is int assert type(sidx) is int assert fconnect is None assert sconnect is None assert type(connect) is connect_cls - def test_connect_call_on_systems(self): - # TODO Finish after the architecture is complete - pass + @pytest.fixture + def load_rod_with_connects_and_indices(self, load_system_with_connects): + system_collection_with_connections_and_indices = load_system_with_connects + + mock_rod_one = self.MockRod(1.0, 2.0, 3.0, 4.0) + mock_rod_one.position_collection = np.array( + [[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [2.0, 0.0, 0.0], [3.0, 0.0, 0.0]] + ) + mock_rod_one.velocity_collection = np.array( + [[1.0, 0.0, 0.0], [1.0, 0.0, 0.0], [1.0, 0.0, 0.0], [1.0, 0.0, 0.0]] + ) + mock_rod_one.external_forces = np.array( + [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]] + ) + system_collection_with_connections_and_indices.append(mock_rod_one) + mock_rod_two = self.MockRod(1.0, 1.0) + mock_rod_two.position_collection = np.array( + [[0.0, 0.0, 0.0], [2.0, 0.0, 0.0], [4.0, 0.0, 0.0], [6.0, 0.0, 0.0]] + ) + mock_rod_two.velocity_collection = np.array( + [[2.0, 0.0, 0.0], [2.0, 0.0, 0.0], [-2.0, 0.0, 0.0], [-2.0, 0.0, 0.0]] + ) + mock_rod_two.external_forces = np.array( + [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]] + ) + system_collection_with_connections_and_indices.append(mock_rod_two) + + def mock_init(self, *args, **kwargs): + self.k = 1.0 + self.nu = 0.1 + + # in place class + MockConnect = type( + "MockConnect", (self.FreeJoint, object), {"__init__": mock_init} + ) + + # Constrain any and all systems + system_collection_with_connections_and_indices.connect( + mock_rod_one, mock_rod_two, 0, 0 + ).using( + MockConnect, 2, 42 + ) # with connection indices + return system_collection_with_connections_and_indices, MockConnect + + def test_connect_call_on_systems(self, load_rod_with_connects_and_indices): + ( + system_collection_with_connections_and_indices, + connect_cls, + ) = load_rod_with_connects_and_indices + + system_collection_with_connections_and_indices._finalize_connections() + system_collection_with_connections_and_indices._call_connections() + + for ( + fidx, + sidx, + fconnect, + sconnect, + connect, + ) in system_collection_with_connections_and_indices._connections: + end_distance_vector = ( + system_collection_with_connections_and_indices._systems[ + sidx + ].position_collection[..., sconnect] + - system_collection_with_connections_and_indices._systems[ + fidx + ].position_collection[..., fconnect] + ) + elastic_force = connect.k * end_distance_vector + + relative_velocity = ( + system_collection_with_connections_and_indices._systems[ + sidx + ].velocity_collection[..., sconnect] + - system_collection_with_connections_and_indices._systems[ + fidx + ].velocity_collection[..., fconnect] + ) + damping_force = connect.nu * relative_velocity + + contact_force = elastic_force + damping_force + + assert_allclose( + system_collection_with_connections_and_indices._systems[ + fidx + ].external_forces[..., fconnect], + contact_force, + ) + assert_allclose( + system_collection_with_connections_and_indices._systems[ + sidx + ].external_forces[..., sconnect], + -1 * contact_force, + atol=Tolerance.atol(), + ) diff --git a/tests/test_modules/test_contact.py b/tests/test_modules/test_contact.py new file mode 100644 index 000000000..ddbd1f7c2 --- /dev/null +++ b/tests/test_modules/test_contact.py @@ -0,0 +1,301 @@ +__doc__ = """ Test modules for contact """ +import numpy as np +import pytest + +from elastica.modules import Contact +from elastica.modules.contact import _Contact + + +class TestContact: + @pytest.fixture(scope="function") + def load_contact(self, request): + # contact between 15th and 23rd rod + return _Contact(15, 23) + + @pytest.mark.parametrize("illegal_contact", [int, list]) + def test_using_with_illegal_contact_throws_assertion_error( + self, load_contact, illegal_contact + ): + with pytest.raises(AssertionError) as excinfo: + load_contact.using(illegal_contact) + assert "{} is not a valid contact class. Did you forget to derive from NoContact?".format( + illegal_contact + ) == str( + excinfo.value + ) + + from elastica.contact_forces import NoContact + + # TODO Add other legal contact later + @pytest.mark.parametrize("legal_contact", [NoContact]) + def test_using_with_legal_contact(self, load_contact, legal_contact): + contact = load_contact + contact.using(legal_contact, 3, 4.0, "5", k=1, l_var="2", j=3.0) + + assert contact._contact_cls == legal_contact + assert contact._args == (3, 4.0, "5") + assert contact._kwargs == {"k": 1, "l_var": "2", "j": 3.0} + + def test_id(self, load_contact): + contact = load_contact + # This is purely for coverage purposes, no actual test + # since its a simple return + assert contact.id() == (15, 23) + + def test_call_without_setting_contact_throws_runtime_error(self, load_contact): + contact = load_contact + + with pytest.raises(RuntimeError) as excinfo: + contact() + assert "No contacts provided to to establish contact between rod-like object id {0}" + " and {1}, but a Contact" + "was intended as per code. Did you forget to" + "call the `using` method?".format(*contact.id()) == str(excinfo.value) + + def test_call_improper_args_throws(self, load_contact): + # Example of bad initiailization function + # This needs at least four args which the user might + # forget to pass later on + def mock_init(self, *args, **kwargs): + self.nu = args[3] # Need at least four args + self.k = kwargs.get("k") + + # in place class + MockContact = type( + "MockContact", (self.NoContact, object), {"__init__": mock_init} + ) + + # The user thinks 4.0 goes to nu, but we don't accept it because of error in + # construction og a Contact class + contact = load_contact + contact.using(MockContact, 4.0, k=1, l_var="2", j=3.0) + + # Actual test is here, this should not throw + with pytest.raises(TypeError) as excinfo: + _ = contact() + assert r"Unable to construct contact class.\n" + r"Did you provide all necessary contact properties?" == str(excinfo.value) + + +class TestContactMixin: + from elastica.modules import BaseSystemCollection + + class SystemCollectionWithContactMixin(BaseSystemCollection, Contact): + pass + + from elastica.rod import RodBase + from elastica.rigidbody import RigidBodyBase + from elastica.surface import SurfaceBase + + class MockRod(RodBase): + def __init__(self, *args, **kwargs): + self.n_elems = 3 # arbitrary number + + class MockRigidBody(RigidBodyBase): + def __init__(self, *args, **kwargs): + self.n_elems = 1 + + class MockSurface(SurfaceBase): + def __init__(self, *args, **kwargs): + self.n_facets = 1 + + @pytest.fixture(scope="function", params=[2, 10]) + def load_system_with_contacts(self, request): + n_sys = request.param + sys_coll_with_contacts = self.SystemCollectionWithContactMixin() + for i_sys in range(n_sys): + sys_coll_with_contacts.append(self.MockRod(2, 3, 4, 5)) + return sys_coll_with_contacts + + """ The following calls test _get_sys_idx_if_valid from BaseSystem indirectly, + and are here because of legacy reasons. I have not removed them because there + are Contacts require testing against multiple indices, which is still use + ful to cross-verify against. + + START + """ + + @pytest.mark.parametrize( + "sys_idx", + [ + (12, 3), + (3, 12), + (-12, 3), + (-3, 12), + (12, -3), + (-12, -3), + (3, -12), + (-3, -12), + ], + ) + def test_contact_with_illegal_index_throws( + self, load_system_with_contacts, sys_idx + ): + system_collection_with_contacts = load_system_with_contacts + + with pytest.raises(AssertionError) as excinfo: + system_collection_with_contacts.detect_contact_between(*sys_idx) + assert "exceeds number of" in str(excinfo.value) + + with pytest.raises(AssertionError) as excinfo: + system_collection_with_contacts.detect_contact_between( + *[np.int_(x) for x in sys_idx] + ) + assert "exceeds number of" in str(excinfo.value) + + def test_contact_with_unregistered_system_throws(self, load_system_with_contacts): + system_collection_with_contacts = load_system_with_contacts + + # Register this rod + mock_rod_registered = self.MockRod(5, 5, 5, 5) + system_collection_with_contacts.append(mock_rod_registered) + # Don't register this rod + mock_rod = self.MockRod(2, 3, 4, 5) + + with pytest.raises(ValueError) as excinfo: + system_collection_with_contacts.detect_contact_between( + mock_rod, mock_rod_registered + ) + assert "was not found, did you" in str(excinfo.value) + + # Switch arguments + with pytest.raises(ValueError) as excinfo: + system_collection_with_contacts.detect_contact_between( + mock_rod_registered, mock_rod + ) + assert "was not found, did you" in str(excinfo.value) + + def test_contact_with_illegal_system_throws(self, load_system_with_contacts): + system_collection_with_contacts = load_system_with_contacts + + # Register this rod + mock_rod_registered = self.MockRod(5, 5, 5, 5) + system_collection_with_contacts.append(mock_rod_registered) + + # Not a rod, but a list! + mock_rod = [1, 2, 3, 5] + + with pytest.raises(TypeError) as excinfo: + system_collection_with_contacts.detect_contact_between( + mock_rod, mock_rod_registered + ) + assert "not a sys" in str(excinfo.value) + + # Switch arguments + with pytest.raises(TypeError) as excinfo: + system_collection_with_contacts.detect_contact_between( + mock_rod_registered, mock_rod + ) + assert "not a sys" in str(excinfo.value) + + """ + END of testing BaseSystem calls + """ + + def test_contact_registers_and_returns_Contact(self, load_system_with_contacts): + system_collection_with_contacts = load_system_with_contacts + + mock_rod_one = self.MockRod(2, 3, 4, 5) + system_collection_with_contacts.append(mock_rod_one) + + mock_rod_two = self.MockRod(4, 5) + system_collection_with_contacts.append(mock_rod_two) + + _mock_contact = system_collection_with_contacts.detect_contact_between( + mock_rod_one, mock_rod_two + ) + assert _mock_contact in system_collection_with_contacts._contacts + assert _mock_contact.__class__ == _Contact + + from elastica.contact_forces import NoContact + + @pytest.fixture + def load_rod_with_contacts(self, load_system_with_contacts): + system_collection_with_contacts = load_system_with_contacts + + mock_rod_one = self.MockRod(2, 3, 4, 5) + system_collection_with_contacts.append(mock_rod_one) + mock_rod_two = self.MockRod(5.0, 5.0) + system_collection_with_contacts.append(mock_rod_two) + + def mock_init(self, *args, **kwargs): + pass + + # in place class + MockContact = type( + "MockContact", (self.NoContact, object), {"__init__": mock_init} + ) + + # Constrain any and all systems + system_collection_with_contacts.detect_contact_between(0, 1).using( + MockContact + ) # index based contact + system_collection_with_contacts.detect_contact_between( + mock_rod_one, mock_rod_two + ).using( + MockContact + ) # system based contact + system_collection_with_contacts.detect_contact_between(0, mock_rod_one).using( + MockContact + ) # index/system based contact + return system_collection_with_contacts, MockContact + + def test_contact_finalize_correctness(self, load_rod_with_contacts): + system_collection_with_contacts, contact_cls = load_rod_with_contacts + + system_collection_with_contacts._finalize_contact() + + for (fidx, sidx, contact) in system_collection_with_contacts._contacts: + assert type(fidx) is int + assert type(sidx) is int + assert type(contact) is contact_cls + + @pytest.fixture + def load_contact_objects_with_incorrect_order(self, load_system_with_contacts): + system_collection_with_contacts = load_system_with_contacts + + mock_rod = self.MockRod(2, 3, 4, 5) + system_collection_with_contacts.append(mock_rod) + mock_rigid_body = self.MockRigidBody(5.0, 5.0) + system_collection_with_contacts.append(mock_rigid_body) + + def mock_init(self, *args, **kwargs): + pass + + # in place class + MockContact = type( + "MockContact", (self.NoContact, object), {"__init__": mock_init} + ) + + # incorrect order contact + system_collection_with_contacts.detect_contact_between( + mock_rigid_body, mock_rod + ).using( + MockContact + ) # rigid body before rod + + return system_collection_with_contacts, MockContact + + def test_contact_check_order(self, load_contact_objects_with_incorrect_order): + ( + system_collection_with_contacts, + contact_cls, + ) = load_contact_objects_with_incorrect_order + + from elastica.rod import RodBase + from elastica.rigidbody import RigidBodyBase + + with pytest.raises(TypeError) as excinfo: + system_collection_with_contacts._finalize_contact() + assert "Systems provided to the contact class have incorrect order. \n" + " First system is {0} and second system is {1} . \n" + " If the first system is a rod, the second system can be a rod, rigid body or surface. \n" + " If the first system is a rigid body, the second system can be a rigid body or surface.".format( + RigidBodyBase, RodBase + ) in str( + excinfo.value + ) + + def test_contact_call_on_systems(self): + # TODO Finish when other contact classes are made + pass