Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update casted collision evaluator to handle fixed start and end states #154

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
175 changes: 147 additions & 28 deletions trajopt/include/trajopt/collision_terms.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,23 @@

namespace trajopt
{
/**
* @brief This contains the different types of expression evaluators used when performing continuous collision checking.
*/
enum class CollisionExpressionEvaluatorType
{
START_FREE_END_FREE = 0, /**< @brief Both start and end state variables are free to be adjusted */
START_FREE_END_FIXED = 1, /**< @brief Only start state variables are free to be adjusted */
START_FIXED_END_FREE = 2 /**< @brief Only end state variables are free to be adjusted */
};

/**
* @brief Base class for collision evaluators containing function that are commonly used between them.
*
* This class also facilitates the caching of the contact results to prevent collision checking from being called
* multiple times throughout the optimization.
*
*/
struct CollisionEvaluator
{
EIGEN_MAKE_ALIGNED_OPERATOR_NEW
Expand All @@ -20,29 +37,58 @@ struct CollisionEvaluator
const Eigen::Isometry3d& world_to_base,
SafetyMarginData::ConstPtr safety_margin_data,
tesseract_collision::ContactTestType contact_test_type,
double longest_valid_segment_length)
: manip_(std::move(manip))
, env_(std::move(env))
, adjacency_map_(std::move(adjacency_map))
, world_to_base_(world_to_base)
, safety_margin_data_(std::move(safety_margin_data))
, contact_test_type_(contact_test_type)
, longest_valid_segment_length_(longest_valid_segment_length)
{
}
double longest_valid_segment_length);
virtual ~CollisionEvaluator() = default;
CollisionEvaluator(const CollisionEvaluator&) = default;
CollisionEvaluator& operator=(const CollisionEvaluator&) = default;
CollisionEvaluator(CollisionEvaluator&&) = default;
CollisionEvaluator& operator=(CollisionEvaluator&&) = default;

/**
* @brief This function calls GetCollisionsCached and stores the distances in a vector
* @param x Optimizer variables
* @param dists Returned distance values
*/
virtual void CalcDists(const DblVec& x, DblVec& dists);

/**
* @brief Convert the contact information into an affine expression
* @param x Optimizer variables
* @param exprs Returned affine expression representation of the contact information
*/
virtual void CalcDistExpressions(const DblVec& x, sco::AffExprVector& exprs) = 0;
virtual void CalcDists(const DblVec& x, DblVec& exprs) = 0;

/**
* @brief Given optimizer parameters calculate the collision results for this evaluator
* @param x Optimizer variables
* @param dist_results Contact results
*/
virtual void CalcCollisions(const DblVec& x, tesseract_collision::ContactResultVector& dist_results) = 0;

/**
* @brief This function checks to see if results are cached for input variable x. If not it calls CalcCollisions and
* caches the results with x as the key.
* @param x Optimizer variables
*/
void GetCollisionsCached(const DblVec& x, tesseract_collision::ContactResultVector&);

/**
* @brief Plot the collision evaluator results
* @param plotter Plotter
* @param x Optimizer variables
*/
virtual void Plot(const tesseract_visualization::Visualization::Ptr& plotter, const DblVec& x) = 0;

/**
* @brief Get the specific optimizer variables associated with this evaluator.
* @return Evaluators variables
*/
virtual sco::VarVector GetVars() = 0;

/**
* @brief Get the safety margin information.
* @return Safety margin information
*/
const SafetyMarginData::ConstPtr getSafetyMarginData() const { return safety_margin_data_; }
Cache<size_t, tesseract_collision::ContactResultVector, 10> m_cache;

Expand All @@ -54,11 +100,57 @@ struct CollisionEvaluator
SafetyMarginData::ConstPtr safety_margin_data_;
tesseract_collision::ContactTestType contact_test_type_;
double longest_valid_segment_length_;
tesseract_environment::StateSolver::Ptr state_solver_;
sco::VarVector vars0_;
sco::VarVector vars1_;
CollisionExpressionEvaluatorType evaluator_type_;

/**
* @brief Calculate the distance expressions when the start is free but the end is fixed
* @param x The current values
* @param exprs The returned expression
*/
void CalcDistExpressionsStartFree(const DblVec& x, sco::AffExprVector& exprs);

/**
* @brief Calculate the distance expressions when the end is free but the start is fixed
* @param x The current values
* @param exprs The returned expression
*/
void CalcDistExpressionsEndFree(const DblVec& x, sco::AffExprVector& exprs);

/**
* @brief Calculate the distance expressions when the start and end are free
* @param x The current values
* @param exprs The returned expression
*/
void CalcDistExpressionsBothFree(const DblVec& x, sco::AffExprVector& exprs);

/**
* @brief This takes contacts results at each interpolated timestep and creates a single contact results map.
* This also updates the cc_time and cc_type for the contact results
* @param contacts_vector Contact results map at each interpolated timestep
* @param contact_results The merged contact results map
*/
void processInterpolatedCollisionResults(std::vector<tesseract_collision::ContactResultMap>& contacts_vector,
tesseract_collision::ContactResultMap& contact_results) const;

/**
* @brief Remove any results that are invalid.
* Invalid state are contacts that occur at fixed states or have distances outside the threshold.
* @param contact_results Contact results vector to process.
*/
void removeInvalidContactResults(tesseract_collision::ContactResultVector& contact_results,
const Eigen::Vector2d& pair_data) const;

private:
CollisionEvaluator() = default;
};

/**
* @brief This collision evaluator only operates on a single state in the trajectory and does not check for collisions
* between states.
*/
struct SingleTimestepCollisionEvaluator : public CollisionEvaluator
{
public:
Expand All @@ -77,20 +169,18 @@ struct SingleTimestepCollisionEvaluator : public CollisionEvaluator
function
*/
void CalcDistExpressions(const DblVec& x, sco::AffExprVector& exprs) override;
/**
* Same as CalcDistExpressions, but just the distances--not the expressions
*/
void CalcDists(const DblVec& x, DblVec& dists) override;
void CalcCollisions(const DblVec& x, tesseract_collision::ContactResultVector& dist_results) override;
void Plot(const tesseract_visualization::Visualization::Ptr& plotter, const DblVec& x) override;
sco::VarVector GetVars() override { return m_vars; }
sco::VarVector GetVars() override { return vars0_; }

private:
sco::VarVector m_vars;
tesseract_collision::DiscreteContactManager::Ptr contact_manager_;
tesseract_environment::StateSolver::Ptr state_solver_;
};

/**
* @brief This collision evaluator operates on two states and checks for collision between the two states using a
* casted collision objects between to intermediate interpolated states.
*/
struct CastCollisionEvaluator : public CollisionEvaluator
{
public:
Expand All @@ -102,18 +192,43 @@ struct CastCollisionEvaluator : public CollisionEvaluator
tesseract_collision::ContactTestType contact_test_type,
double longest_valid_segment_length,
sco::VarVector vars0,
sco::VarVector vars1);
sco::VarVector vars1,
CollisionExpressionEvaluatorType type);
void CalcDistExpressions(const DblVec& x, sco::AffExprVector& exprs) override;
void CalcDists(const DblVec& x, DblVec& exprs) override;
void CalcCollisions(const DblVec& x, tesseract_collision::ContactResultVector& dist_results) override;
void Plot(const tesseract_visualization::Visualization::Ptr& plotter, const DblVec& x) override;
sco::VarVector GetVars() override { return concat(m_vars0, m_vars1); }
sco::VarVector GetVars() override { return concat(vars0_, vars1_); }

private:
sco::VarVector m_vars0;
sco::VarVector m_vars1;
tesseract_collision::ContinuousContactManager::Ptr contact_manager_;
tesseract_environment::StateSolver::Ptr state_solver_;
std::function<void(const DblVec&, sco::AffExprVector&)> fn_;
};

/**
* @brief This collision evaluator operates on two states and checks for collision between the two states using a
* descrete collision objects at each intermediate interpolated states.
*/
struct DiscreteCollisionEvaluator : public CollisionEvaluator
{
public:
DiscreteCollisionEvaluator(tesseract_kinematics::ForwardKinematics::ConstPtr manip,
tesseract_environment::Environment::ConstPtr env,
tesseract_environment::AdjacencyMap::ConstPtr adjacency_map,
const Eigen::Isometry3d& world_to_base,
SafetyMarginData::ConstPtr safety_margin_data,
tesseract_collision::ContactTestType contact_test_type,
double longest_valid_segment_length,
sco::VarVector vars0,
sco::VarVector vars1,
CollisionExpressionEvaluatorType type);
void CalcDistExpressions(const DblVec& x, sco::AffExprVector& exprs) override;
void CalcCollisions(const DblVec& x, tesseract_collision::ContactResultVector& dist_results) override;
void Plot(const tesseract_visualization::Visualization::Ptr& plotter, const DblVec& x) override;
sco::VarVector GetVars() override { return concat(vars0_, vars1_); }

private:
tesseract_collision::DiscreteContactManager::Ptr contact_manager_;
std::function<void(const DblVec&, sco::AffExprVector&)> fn_;
};

class TRAJOPT_API CollisionCost : public sco::Cost, public Plotter
Expand All @@ -127,7 +242,7 @@ class TRAJOPT_API CollisionCost : public sco::Cost, public Plotter
SafetyMarginData::ConstPtr safety_margin_data,
tesseract_collision::ContactTestType contact_test_type,
sco::VarVector vars);
/* constructor for cast cost */
/* constructor for discrete continuous and cast continuous cost */
CollisionCost(tesseract_kinematics::ForwardKinematics::ConstPtr manip,
tesseract_environment::Environment::ConstPtr env,
tesseract_environment::AdjacencyMap::ConstPtr adjacency_map,
Expand All @@ -136,7 +251,9 @@ class TRAJOPT_API CollisionCost : public sco::Cost, public Plotter
tesseract_collision::ContactTestType contact_test_type,
double longest_valid_segment_length,
sco::VarVector vars0,
sco::VarVector vars1);
sco::VarVector vars1,
CollisionExpressionEvaluatorType type,
bool discrete);
sco::ConvexObjective::Ptr convex(const DblVec& x, sco::Model* model) override;
double value(const DblVec&) override;
void Plot(const tesseract_visualization::Visualization::Ptr& plotter, const DblVec& x) override;
Expand All @@ -157,7 +274,7 @@ class TRAJOPT_API CollisionConstraint : public sco::IneqConstraint
SafetyMarginData::ConstPtr safety_margin_data,
tesseract_collision::ContactTestType contact_test_type,
sco::VarVector vars);
/* constructor for cast cost */
/* constructor for discrete continuous and cast continuous cost */
CollisionConstraint(tesseract_kinematics::ForwardKinematics::ConstPtr manip,
tesseract_environment::Environment::ConstPtr env,
tesseract_environment::AdjacencyMap::ConstPtr adjacency_map,
Expand All @@ -166,7 +283,9 @@ class TRAJOPT_API CollisionConstraint : public sco::IneqConstraint
tesseract_collision::ContactTestType contact_test_type,
double longest_valid_segment_length,
sco::VarVector vars0,
sco::VarVector vars1);
sco::VarVector vars1,
CollisionExpressionEvaluatorType type,
bool discrete);
sco::ConvexConstraints::Ptr convex(const DblVec& x, sco::Model* model) override;
DblVec value(const DblVec&) override;
void Plot(const DblVec& x);
Expand Down
17 changes: 15 additions & 2 deletions trajopt/include/trajopt/problem_description.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,9 @@ struct UserDefinedTermInfo : public TermInfo
/** @brief Timesteps over which to apply term */
int first_step, last_step;

/** @brief Indicated if a step is fixed and its variables cannot be changed */
std::vector<int> fixed_steps;

/** @brief The user defined error function */
sco::VectorOfVector::func error_function;

Expand Down Expand Up @@ -527,6 +530,13 @@ struct JointJerkTermInfo : public TermInfo
JointJerkTermInfo() : TermInfo(TT_COST | TT_CNT) {}
};

enum class CollisionEvaluatorType
{
SINGLE_TIMESTEP = 0,
DISCRETE_CONTINUOUS = 1,
CAST_CONTINUOUS = 2,
};

/**
\brief %Collision penalty

Expand All @@ -544,8 +554,11 @@ struct CollisionTermInfo : public TermInfo
/** @brief first_step and last_step are inclusive */
int first_step, last_step;

/** @brief Indicate if continuous collision checking should be used. */
bool continuous;
/** @brief Indicate the type of collision checking that should be used. */
CollisionEvaluatorType evaluator_type;

/** @brief Indicated if a step is fixed and its variables cannot be changed */
std::vector<int> fixed_steps;

/** @brief Set the resolution at which state validity needs to be verified in order for a motion between two states
* to be considered valid. If norm(state1 - state0) > longest_valid_segment_length.
Expand Down
Loading