Skip to content

Commit

Permalink
Adds SmoothView (#185)
Browse files Browse the repository at this point in the history
* almost works [skip ci]

* back up [skip ci]

* r2g

* Committing clang-format changes

---------

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
  • Loading branch information
ryanmrichard and github-actions[bot] authored Oct 16, 2024
1 parent 68f1fd0 commit 4d36c58
Show file tree
Hide file tree
Showing 15 changed files with 1,043 additions and 6 deletions.
7 changes: 6 additions & 1 deletion include/tensorwrapper/detail_/polymorphic_base.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -103,8 +103,13 @@ class PolymorphicBase {
* @throw None No throw guarantee.
*/
bool are_equal(const_base_reference rhs) const noexcept {
// Downcast *this so it can be passed to are_equal_
const_base_reference plhs = static_cast<const_base_reference>(*this);
return are_equal_(rhs) && rhs.are_equal_(plhs);

// This line is necessary if are_equal_ is overriden in BaseType
const PolymorphicBase& rhs_upcast = rhs;

return are_equal_(rhs) && rhs_upcast.are_equal_(plhs);
}

/** @brief Determines if *this and @p rhs are polymorphically different.
Expand Down
61 changes: 61 additions & 0 deletions include/tensorwrapper/detail_/view_traits.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
/*
* Copyright 2024 NWChemEx-Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once
#include <type_traits>

namespace tensorwrapper::detail_ {

/** @brief Is the cast from @p FromType to @p ToType just adding const?
*
* A common TMP pattern in implementing views is needing to convert mutable
* views to read-only views. This trait can be used to compare the template
* type parameters of two views (assuming the views are templated on what
* object they are acting like) in order to determine if they represent a
* conversion from @p FromType to @p ToType such that @p ToType is
* `const FromType`. If @p ToType is `const FromType` this template variable
* will be set to true, otherwise it will be set to false.
*
* @tparam FromType The type we are converting from.
* @tparam ToType The type we are converting to.
*/
template<typename FromType, typename ToType>
constexpr bool is_mutable_to_immutable_cast_v =
!std::is_const_v<FromType> && // FromType is NOT read-only
std::is_const_v<ToType> && // ToType is read-only
std::is_same_v<const FromType, ToType>; // They differ by const-ness

/** @brief Disables a templated function except when
* `is_mutable_to_immutable_cast_v<FromType, ToType>` evaluates to true.
*
* If `View` is a template class with template parameter type `T`, we want the
* implicit conversion from `View<T>` to `View<const T>` to exist. In practice,
* this leaves us with two options: partial specialization of `View` for
* const-qualified types or use of SFINAE to disable the conversion. We prefer
* the latter as the former requires us to duplicate the entirety of the
* class. This template type will disable the accompanying function via SFINAE
* if @p ToType is not `const FromType`.
*
* @tparam FromType The type we are converting from. Expected to be the
* template type parameter of the view we are casting from.
* @tparam ToType The type we are converting to. Expected to be the template
* type parameter of the view we are casting to.
*/
template<typename FromType, typename ToType>
using enable_if_mutable_to_immutable_cast_t =
std::enable_if_t<is_mutable_to_immutable_cast_v<FromType, ToType>>;

} // namespace tensorwrapper::detail_
57 changes: 52 additions & 5 deletions include/tensorwrapper/shape/shape_base.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@
#include <cstddef>
#include <memory>
#include <tensorwrapper/detail_/polymorphic_base.hpp>
#include <tensorwrapper/shape/shape_traits.hpp>
#include <tensorwrapper/shape/smooth_view.hpp>

namespace tensorwrapper::shape {

/** @brief Code factorization for the various types of shapes.
Expand All @@ -34,19 +37,29 @@ namespace tensorwrapper::shape {
* - get_rank_()
* - get_size_()
*/
class ShapeBase : public detail_::PolymorphicBase<ShapeBase> {
class ShapeBase : public tensorwrapper::detail_::PolymorphicBase<ShapeBase> {
private:
/// Type implementing the traits of this
using traits_type = ShapeTraits<ShapeBase>;

public:
/// Type all shapes inherit from
using shape_base = ShapeBase;
using shape_base = typename traits_type::shape_base;

/// Type of a pointer to the base of a shape object
using base_pointer = std::unique_ptr<shape_base>;
using base_pointer = typename traits_type::base_pointer;

/// Type used to hold the rank of a tensor
using rank_type = unsigned short;
using rank_type = typename traits_type::rank_type;

/// Type used to specify the number of elements in the shape
using size_type = std::size_t;
using size_type = typename traits_type::size_type;

/// Type of an object acting like a mutable reference to a Smooth shape
using smooth_reference = SmoothView<Smooth>;

/// Type of an object acting like a read-only reference to a Smooth shape
using const_smooth_reference = SmoothView<const Smooth>;

/// No-op for ShapeBase because ShapeBase has no state
ShapeBase() noexcept = default;
Expand Down Expand Up @@ -83,6 +96,34 @@ class ShapeBase : public detail_::PolymorphicBase<ShapeBase> {
*/
size_type size() const noexcept { return get_size_(); }

/** @brief Returns a view of *this as a Smooth object.
*
* It is possible to view any shape as a smooth shape. For more exotic
* shapes this may require flattening nestings and padding dimensions.
* This method ultimately dispatches to the as_smooth_ overload of the
* derived class to control how to smooth the shape out.
*
* @return A view of *this consistent with thinking of *this as a Smooth
* object.
*
* @throw std::bad_alloc if there is a problem allocating the view. Strong
* throw guarantee.
*/
smooth_reference as_smooth() { return as_smooth_(); }

/** @brief Returns a read-only view of *this as a Smooth object.
*
* This method works the same as the non-const version except that the
* resulting view is read-only.
*
* @return A read-only view of *this consistent with thinking of *this as
* a Smooth object.
*
* @throw std::bad_alloc if there is a problem allocating the view. Strong
* throw guarantee.
*/
const_smooth_reference as_smooth() const { return as_smooth_(); }

protected:
/** @brief Used to implement rank().
*
Expand All @@ -108,6 +149,12 @@ class ShapeBase : public detail_::PolymorphicBase<ShapeBase> {
* subject to a no-throw guarantee.
*/
virtual size_type get_size_() const noexcept = 0;

/// Derived class should override to be consistent with as_smooth()
virtual smooth_reference as_smooth_() = 0;

/// Derived class should override to be consistent with as_smooth() const
virtual const_smooth_reference as_smooth_() const = 0;
};

} // namespace tensorwrapper::shape
34 changes: 34 additions & 0 deletions include/tensorwrapper/shape/shape_fwd.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
/*
* Copyright 2024 NWChemEx-Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once

namespace tensorwrapper::shape {
namespace detail_ {

template<typename SmoothType>
class SmoothViewPIMPL;

}

class ShapeBase;

class Smooth;

template<typename SmoothType>
class SmoothView;

} // namespace tensorwrapper::shape
72 changes: 72 additions & 0 deletions include/tensorwrapper/shape/shape_traits.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
/*
* Copyright 2024 NWChemEx-Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once
#include <memory>
#include <tensorwrapper/shape/shape_fwd.hpp>

namespace tensorwrapper::shape {

template<typename ShapeType>
struct ShapeTraits;

template<>
struct ShapeTraits<ShapeBase> {
using shape_base = ShapeBase;
using base_pointer = std::unique_ptr<shape_base>;
using rank_type = unsigned short;
using size_type = std::size_t;
};

template<>
struct ShapeTraits<const ShapeBase> {
using shape_base = ShapeBase;
using base_pointer = std::unique_ptr<shape_base>;
using rank_type = unsigned short;
using size_type = std::size_t;
};

template<>
struct ShapeTraits<Smooth> : public ShapeTraits<ShapeBase> {
using value_type = Smooth;
using const_value_type = const value_type;
using reference = value_type&;
using const_reference = const value_type&;
using pointer = value_type*;
using const_pointer = const value_type*;
};

template<>
struct ShapeTraits<const Smooth> : public ShapeTraits<const ShapeBase> {
using value_type = Smooth;
using const_value_type = const value_type;
using reference = const value_type&;
using const_reference = const value_type&;
using pointer = const value_type*;
using const_pointer = const value_type*;
};

template<typename T>
struct ShapeTraits<SmoothView<T>> {
using smooth_traits = ShapeTraits<T>;
using pimpl_type = detail_::SmoothViewPIMPL<T>;
using const_pimpl_type =
detail_::SmoothViewPIMPL<typename smooth_traits::const_value_type>;
using pimpl_pointer = std::unique_ptr<pimpl_type>;
using const_pimpl_pointer = std::unique_ptr<const_pimpl_type>;
};

} // namespace tensorwrapper::shape
6 changes: 6 additions & 0 deletions include/tensorwrapper/shape/smooth.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,12 @@ class Smooth : public ShapeBase {
size_type(1), std::multiplies<size_type>());
}

smooth_reference as_smooth_() override { return smooth_reference(*this); }

virtual const_smooth_reference as_smooth_() const override {
return const_smooth_reference(*this);
}

/// Implements are_equal by calling ShapeBase::are_equal_impl_
bool are_equal_(const ShapeBase& rhs) const noexcept override {
return are_equal_impl_<Smooth>(rhs);
Expand Down
Loading

0 comments on commit 4d36c58

Please sign in to comment.