From 62264fbe152b4286efb4fda1e250cfdacbfe8016 Mon Sep 17 00:00:00 2001 From: Alex Bilger Date: Mon, 10 Feb 2025 14:20:15 +0100 Subject: [PATCH] [StateContainer] Allow coord difference in vOp for rigids (#5253) * [StateContainer] Allow coord difference in vOp for rigids * check that DataTypes support coordDifference --- .../statecontainer/MechanicalObject.inl | 34 +++++++++++++++++-- .../tests/MechanicalObjectVOp_test.cpp | 30 ++++++++++++++++ 2 files changed, 61 insertions(+), 3 deletions(-) diff --git a/Sofa/Component/StateContainer/src/sofa/component/statecontainer/MechanicalObject.inl b/Sofa/Component/StateContainer/src/sofa/component/statecontainer/MechanicalObject.inl index 7ccd7341dde..aa2b2dc6fd6 100644 --- a/Sofa/Component/StateContainer/src/sofa/component/statecontainer/MechanicalObject.inl +++ b/Sofa/Component/StateContainer/src/sofa/component/statecontainer/MechanicalObject.inl @@ -2098,9 +2098,37 @@ void MechanicalObject::vOp(const core::ExecParams* params, core::VecI } else { - // v = a+b*f - APPLY_PREDICATE_THREEPARAMS(vOp_vabf, v, a, b) - msg_error_when(!canApplyPredicate) << "Cannot apply vector operation v = a+b*f (" << v << ',' << a << ',' << b << ',' << f << ")"; + const auto generalCase = [&]() + { + // v = a+b*f + APPLY_PREDICATE_THREEPARAMS(vOp_vabf, v, a, b) + msg_error_when(!canApplyPredicate) << "Cannot apply vector operation v = a+b*f (" << v << ',' << a << ',' << b << ',' << f << ")"; + }; + + if constexpr (requires (Coord ca, Coord cb) {DataTypes::coordDifference(ca, cb);}) + { + if (f == -1._sreal && + (v.type == core::V_DERIV && a.type == core::V_COORD && b.type == core::V_COORD)) + { + // v = a-b + auto vv = this->getWriteOnlyAccessor(v); + auto va = this->getReadAccessor(a); + auto vb = this->getReadAccessor(b); + vv.resize(vb.size()); + for (unsigned int i = 0; i < vv.size(); ++i) + { + vv[i] = DataTypes::coordDifference(va[i], vb[i]); + } + } + else + { + generalCase(); + } + } + else + { + generalCase(); + } } } } diff --git a/Sofa/Component/StateContainer/tests/MechanicalObjectVOp_test.cpp b/Sofa/Component/StateContainer/tests/MechanicalObjectVOp_test.cpp index 811d4d22b9e..a7be26a616e 100644 --- a/Sofa/Component/StateContainer/tests/MechanicalObjectVOp_test.cpp +++ b/Sofa/Component/StateContainer/tests/MechanicalObjectVOp_test.cpp @@ -731,6 +731,31 @@ struct MechanicalObjectVOpTest : public testing::BaseTest checkVecValues(isRigid ? velocityCoefficient : forceCoefficient + positionCoefficient * 12); } + void equalCoordDifference() const + { + // v = a-b + m_mechanicalObject->vOp(nullptr, core::vec_id::write_access::velocity, core::vec_id::read_access::restPosition, core::vec_id::read_access::position, -1_sreal); + + unsigned int index {}; + auto vv = sofa::helper::getReadAccessor(*m_mechanicalObject->read(core::vec_id::read_access::velocity)); + auto va = sofa::helper::getReadAccessor(*m_mechanicalObject->read(core::vec_id::read_access::restPosition)); + auto vb = sofa::helper::getReadAccessor(*m_mechanicalObject->read(core::vec_id::read_access::position)); + + ASSERT_EQ(vv.size(), 10); + for (std::size_t i = 0; i < vv.size(); ++i) + { + const auto& v = vv[i]; + const auto diff = DataTypes::coordDifference(va[i], vb[i]); + for (std::size_t j = 0; j < v.size(); ++j) + { + EXPECT_FLOATINGPOINT_EQ(v[j], diff[j]) + } + ++index; + } + + checkVecValues(positionCoefficient); + } + typename MO::SPtr m_mechanicalObject; }; @@ -949,4 +974,9 @@ TYPED_TEST(MechanicalObjectVOpTest, equalSumWithScalarVelocityMix2) this->equalSumWithScalarVelocityMix2(); } +TYPED_TEST(MechanicalObjectVOpTest, equalCoordDifference) +{ + this->equalCoordDifference(); +} + } \ No newline at end of file