diff --git a/cmake/CMakeLists.txt b/cmake/CMakeLists.txt index 9ee7470b0b2b3..d9c6068c0d1c4 100644 --- a/cmake/CMakeLists.txt +++ b/cmake/CMakeLists.txt @@ -53,6 +53,7 @@ option(onnxruntime_USE_NNAPI "Build with DNNLibrary for Android NNAPI support" O option(onnxruntime_USE_MLAS "Use optimized blas library for GEMM and 2D Convolution" ON) option(onnxruntime_USE_MKLDNN "Build with MKL-DNN support" OFF) option(onnxruntime_USE_MKLML "Build MKL-DNN with MKL-ML binary dependency" OFF) +option(onnxruntime_USE_AUTOML "Build AutoML support" ON) option(onnxruntime_USE_NGRAPH "Build with nGraph support" OFF) option(onnxruntime_USE_OPENBLAS "Use openblas" OFF) option(onnxruntime_DEV_MODE "Enable developer warnings and treat most of them as error." OFF) @@ -646,6 +647,12 @@ include(onnxruntime_optimizer.cmake) include(onnxruntime_session.cmake) include(onnxruntime_mlas.cmake) +if(onnxruntime_USE_AUTOML) + add_definitions(-DMICROSOFT_AUTOML) + # Build shared featurizer library + include(onnxruntime_automl_featurizers.cmake) +endif() + if(WIN32) list(APPEND onnxruntime_EXTERNAL_LIBRARIES Shlwapi) list(APPEND onnxruntime_EXTERNAL_LIBRARIES debug Dbghelp) diff --git a/cmake/onnxruntime_automl_featurizers.cmake b/cmake/onnxruntime_automl_featurizers.cmake new file mode 100644 index 0000000000000..daffe92842826 --- /dev/null +++ b/cmake/onnxruntime_automl_featurizers.cmake @@ -0,0 +1,44 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# This source code should not depend on the onnxruntime and may be built independently + +file(GLOB automl_featurizers_srcs CONFIGURE_DEPENDS + "${ONNXRUNTIME_ROOT}/core/automl/featurizers/src/FeaturizerPrep/*.h" + "${ONNXRUNTIME_ROOT}/core/automl/featurizers/src/FeaturizerPrep/Featurizers/*.h" + "${ONNXRUNTIME_ROOT}/core/automl/featurizers/src/FeaturizerPrep/Featurizers/*.cpp" +) + +source_group(TREE ${ONNXRUNTIME_ROOT}/core/automl/ FILES ${onnxruntime_automl_featurizers_srcs}) + +add_library(automl_featurizers ${automl_featurizers_srcs}) + +target_include_directories(automl_featurizers PRIVATE ${ONNXRUNTIME_ROOT} PUBLIC ${CMAKE_CURRENT_BINARY_DIR}) + +set_target_properties(automl_featurizers PROPERTIES FOLDER "AutoMLFeaturizers") + +# Individual featurizers unit tests added at bulk +file(GLOB automl_featurizers_tests_srcs + "${ONNXRUNTIME_ROOT}/core/automl/featurizers/src/FeaturizerPrep/Featurizers/UnitTests/*.cpp" +) + +list(APPEND automl_featurizers_tests_srcs + "${ONNXRUNTIME_ROOT}/core/automl/featurizers/src/FeaturizerPrep/UnitTests/Traits_UnitTests.cpp" + "${ONNXRUNTIME_ROOT}/core/automl/featurizers/src/FeaturizerPrep/UnitTests/Featurizer_UnitTest.cpp" + "${ONNXRUNTIME_ROOT}/core/automl/featurizers/src/FeaturizerPrep/UnitTests/test_main.cpp" +) + +add_executable(automl_featurizers_unittests ${automl_featurizers_tests_srcs}) +add_dependencies(automl_featurizers_unittests automl_featurizers) +target_link_libraries(automl_featurizers_unittests PRIVATE gtest automl_featurizers) +source_group(TREE ${ONNXRUNTIME_ROOT}/core/automl/ FILES ${automl_featurizers_tests_srcs}) +set_target_properties(automl_featurizers_unittests PROPERTIES FOLDER "AutoMLFeaturizers") +add_test(NAME automl_featurizers_unittests + COMMAND automl_featurizers_unittests + WORKING_DIRECTORY $ +) + + +if (WIN32) + # Add Code Analysis properties to enable C++ Core checks. Have to do it via a props file include. + set_target_properties(automl_featurizers PROPERTIES VS_USER_PROPS ${PROJECT_SOURCE_DIR}/ConfigureVisualStudioCodeAnalysis.props) +endif() diff --git a/cmake/onnxruntime_graph.cmake b/cmake/onnxruntime_graph.cmake index 366eadf680fff..4c05a3307bff0 100644 --- a/cmake/onnxruntime_graph.cmake +++ b/cmake/onnxruntime_graph.cmake @@ -14,6 +14,13 @@ if (onnxruntime_DISABLE_CONTRIB_OPS) ) endif() +if(NOT onnxruntime_USE_AUTOML) + list(REMOVE_ITEM onnxruntime_graph_src + "${ONNXRUNTIME_ROOT}/core/graph/automl_ops/*.h" + "${ONNXRUNTIME_ROOT}/core/graph/automl_ops/*.cc" + ) +endif() + file(GLOB_RECURSE onnxruntime_ir_defs_src CONFIGURE_DEPENDS "${ONNXRUNTIME_ROOT}/core/defs/*.cc" ) @@ -21,6 +28,7 @@ file(GLOB_RECURSE onnxruntime_ir_defs_src CONFIGURE_DEPENDS add_library(onnxruntime_graph ${onnxruntime_graph_src} ${onnxruntime_ir_defs_src}) add_dependencies(onnxruntime_graph onnx_proto gsl) onnxruntime_add_include_to_target(onnxruntime_graph onnxruntime_common gsl onnx onnx_proto protobuf::libprotobuf) + target_include_directories(onnxruntime_graph PRIVATE ${ONNXRUNTIME_ROOT}) set_target_properties(onnxruntime_graph PROPERTIES FOLDER "ONNXRuntime") set_target_properties(onnxruntime_graph PROPERTIES LINKER_LANGUAGE CXX) diff --git a/cmake/onnxruntime_providers.cmake b/cmake/onnxruntime_providers.cmake index e072f830766d5..40478d7a8472d 100644 --- a/cmake/onnxruntime_providers.cmake +++ b/cmake/onnxruntime_providers.cmake @@ -25,6 +25,16 @@ file(GLOB_RECURSE onnxruntime_cuda_contrib_ops_cu_srcs CONFIGURE_DEPENDS "${ONNXRUNTIME_ROOT}/contrib_ops/cuda/*.cuh" ) +file(GLOB onnxruntime_cpu_automl_cc_srcs CONFIGURE_DEPENDS + "${ONNXRUNTIME_ROOT}/automl_ops/cpu_automl_kernels.h" + "${ONNXRUNTIME_ROOT}/automl_ops/cpu_automl_kernels.cc" + "${ONNXRUNTIME_ROOT}/automl_ops/automl_types.h" + "${ONNXRUNTIME_ROOT}/automl_ops/automl_types.cc" + "${ONNXRUNTIME_ROOT}/automl_ops/automl_featurizers.h" + "${ONNXRUNTIME_ROOT}/automl_ops/cpu/*.h" + "${ONNXRUNTIME_ROOT}/automl_ops/cpu/*.cc" +) + file(GLOB onnxruntime_providers_common_srcs CONFIGURE_DEPENDS "${ONNXRUNTIME_ROOT}/core/providers/*.h" "${ONNXRUNTIME_ROOT}/core/providers/*.cc" @@ -55,17 +65,30 @@ if(onnxruntime_USE_NNAPI) list(APPEND ONNXRUNTIME_PROVIDER_NAMES nnapi) endif() source_group(TREE ${ONNXRUNTIME_ROOT}/core FILES ${onnxruntime_providers_common_srcs} ${onnxruntime_providers_srcs}) -# add using ONNXRUNTIME_ROOT so they show up under the 'contrib_ops' folder in Visual Studio -source_group(TREE ${ONNXRUNTIME_ROOT} FILES ${onnxruntime_cpu_contrib_ops_srcs}) + +set(onnxruntime_providers_src ${onnxruntime_providers_common_srcs} ${onnxruntime_providers_srcs}) # disable contrib ops conditionally -if(onnxruntime_DISABLE_CONTRIB_OPS) - add_library(onnxruntime_providers ${onnxruntime_providers_common_srcs} ${onnxruntime_providers_srcs}) -else() - add_library(onnxruntime_providers ${onnxruntime_providers_common_srcs} ${onnxruntime_providers_srcs} ${onnxruntime_cpu_contrib_ops_srcs}) +if(NOT onnxruntime_DISABLE_CONTRIB_OPS) + # add using ONNXRUNTIME_ROOT so they show up under the 'contrib_ops' folder in Visual Studio + source_group(TREE ${ONNXRUNTIME_ROOT} FILES ${onnxruntime_cpu_contrib_ops_srcs}) + list(APPEND onnxruntime_providers_src ${onnxruntime_cpu_contrib_ops_srcs}) +endif() + +if (onnxruntime_USE_AUTOML) + source_group(TREE ${ONNXRUNTIME_ROOT}/ FILES ${onnxruntime_cpu_automl_cc_srcs}) + list(APPEND onnxruntime_providers_src ${onnxruntime_cpu_automl_cc_srcs}) endif() +add_library(onnxruntime_providers ${onnxruntime_providers_src}) onnxruntime_add_include_to_target(onnxruntime_providers onnxruntime_common onnxruntime_framework gsl onnx onnx_proto protobuf::libprotobuf) + +if (onnxruntime_USE_AUTOML) + add_dependencies(onnxruntime_providers automl_featurizers) + onnxruntime_add_include_to_target(onnxruntime_providers automl_featurizers) + target_link_libraries(onnxruntime_providers automl_featurizers) +endif() + if(HAS_DEPRECATED_COPY) #temporarily ignore this warning #see: https://en.wikipedia.org/wiki/Rule_of_three_(C%2B%2B_programming) diff --git a/cmake/onnxruntime_unittests.cmake b/cmake/onnxruntime_unittests.cmake index 368ee8790d718..d63cd90f7e8f5 100644 --- a/cmake/onnxruntime_unittests.cmake +++ b/cmake/onnxruntime_unittests.cmake @@ -126,6 +126,12 @@ if(NOT onnxruntime_DISABLE_CONTRIB_OPS) "${TEST_SRC_DIR}/contrib_ops/*.cc") endif() +if(onnxruntime_USE_AUTOML) + list(APPEND onnxruntime_test_providers_src_patterns + "${TEST_SRC_DIR}/automl_ops/*.h" + "${TEST_SRC_DIR}/automl_ops/*.cc") +endif() + file(GLOB onnxruntime_test_providers_src CONFIGURE_DEPENDS ${onnxruntime_test_providers_src_patterns}) file(GLOB_RECURSE onnxruntime_test_providers_cpu_src CONFIGURE_DEPENDS @@ -209,6 +215,10 @@ if(onnxruntime_USE_NNAPI) list(APPEND onnxruntime_test_providers_dependencies onnxruntime_providers_nnapi) endif() +if(onnxruntime_USE_AUTOML) + list(APPEND onnxruntime_test_providers_dependencies automl_featurizers) +endif() + file(GLOB_RECURSE onnxruntime_test_tvm_src CONFIGURE_DEPENDS "${ONNXRUNTIME_ROOT}/test/tvm/*.h" "${ONNXRUNTIME_ROOT}/test/tvm/*.cc" diff --git a/include/onnxruntime/core/framework/op_kernel.h b/include/onnxruntime/core/framework/op_kernel.h index e02027a328fcf..d396551a1b407 100644 --- a/include/onnxruntime/core/framework/op_kernel.h +++ b/include/onnxruntime/core/framework/op_kernel.h @@ -231,6 +231,11 @@ template KernelCreateInfo BuildKernelCreateInfo(); } // namespace contrib +namespace automl { +template +KernelCreateInfo BuildKernelCreateInfo(); +} // namespace automl + namespace contrib { namespace cuda { template diff --git a/include/onnxruntime/core/graph/constants.h b/include/onnxruntime/core/graph/constants.h index 5872228f383d2..6a960e82a3074 100644 --- a/include/onnxruntime/core/graph/constants.h +++ b/include/onnxruntime/core/graph/constants.h @@ -19,6 +19,7 @@ constexpr const char* kOnnxDomainAlias = "ai.onnx"; constexpr const char* kMLDomain = "ai.onnx.ml"; constexpr const char* kMSDomain = "com.microsoft"; constexpr const char* kMSNchwcDomain = "com.microsoft.nchwc"; +constexpr const char* kMSAutoMLDomain = "com.microsoft.automl"; constexpr const char* kNGraphDomain = "com.intel.ai"; constexpr const char* kCpuExecutionProvider = "CPUExecutionProvider"; constexpr const char* kCudaExecutionProvider = "CUDAExecutionProvider"; diff --git a/onnxruntime/automl_ops/automl_featurizers.h b/onnxruntime/automl_ops/automl_featurizers.h new file mode 100644 index 0000000000000..37e6e982d9a62 --- /dev/null +++ b/onnxruntime/automl_ops/automl_featurizers.h @@ -0,0 +1,8 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// Cumulative header with automl featurizers includes exposed to +// ORT +#pragma once + +#include "core/automl/featurizers/src/FeaturizerPrep/Featurizers/DateTimeFeaturizer.h" diff --git a/onnxruntime/automl_ops/automl_types.cc b/onnxruntime/automl_ops/automl_types.cc new file mode 100644 index 0000000000000..8f0cb77701606 --- /dev/null +++ b/onnxruntime/automl_ops/automl_types.cc @@ -0,0 +1,39 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/common/common.h" +#include "core/framework/data_types.h" +#include "core/framework/op_kernel.h" + +#include "automl_ops/automl_types.h" +#include "automl_ops/automl_featurizers.h" + +namespace dtf = Microsoft::Featurizer::DateTimeFeaturizer; + +namespace onnxruntime { + +// This temporary to register custom types so ORT is aware of it +// although it still can not serialize such a type. +// These character arrays must be extern so the resulting instantiated template +// is globally unique + +extern const char kMsAutoMLDomain[] = "com.microsoft.automl"; + +extern const char kTimepointName[] = "DateTimeFeaturizer_TimePoint"; +// This has to be under onnxruntime to properly specialize a function template +ORT_REGISTER_OPAQUE_TYPE(dtf::TimePoint, kMsAutoMLDomain, kTimepointName); + +namespace automl { + +#define REGISTER_CUSTOM_PROTO(TYPE, reg_fn) \ + { \ + MLDataType mltype = DataTypeImpl::GetType(); \ + reg_fn(mltype); \ + } + +void RegisterAutoMLTypes(const std::function& reg_fn) { + REGISTER_CUSTOM_PROTO(dtf::TimePoint, reg_fn); +} +#undef REGISTER_CUSTOM_PROTO +} // namespace automl +} // namespace onnxruntime diff --git a/onnxruntime/automl_ops/automl_types.h b/onnxruntime/automl_ops/automl_types.h new file mode 100644 index 0000000000000..798c6778966bb --- /dev/null +++ b/onnxruntime/automl_ops/automl_types.h @@ -0,0 +1,13 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/framework/data_types.h" +#include + +namespace onnxruntime { +namespace automl { +void RegisterAutoMLTypes(const std::function& reg_fn); +} // namespace automl +} // namespace onnxruntime diff --git a/onnxruntime/automl_ops/cpu/datetime_transformer.cc b/onnxruntime/automl_ops/cpu/datetime_transformer.cc new file mode 100644 index 0000000000000..08b23e57f1324 --- /dev/null +++ b/onnxruntime/automl_ops/cpu/datetime_transformer.cc @@ -0,0 +1,42 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/common/common.h" +#include "core/framework/data_types.h" +#include "core/framework/op_kernel.h" + +#include "core/automl/featurizers/src/FeaturizerPrep/Featurizers/DateTimeFeaturizer.h" + +namespace dtf = Microsoft::Featurizer::DateTimeFeaturizer; + +namespace onnxruntime { +namespace automl { + +class DateTimeTransformer final : public OpKernel { + public: + explicit DateTimeTransformer(const OpKernelInfo& info) : OpKernel(info) {} + Status Compute(OpKernelContext* context) const override; +}; + +Status DateTimeTransformer::Compute(OpKernelContext* ctx) const { + Status s; + auto input_tensor = ctx->Input(0); + dtf::TimePoint* output = ctx->Output(0); + + int64_t tp = *input_tensor->Data(); + std::chrono::system_clock::time_point sys_time{std::chrono::seconds(tp)}; + *output = std::move(dtf::SystemToDPTimePoint(sys_time)); + return s; +} + +ONNX_OPERATOR_KERNEL_EX( + DateTimeTransformer, + kMSAutoMLDomain, + 1, + kCpuExecutionProvider, + KernelDefBuilder() + .TypeConstraint("T1", DataTypeImpl::GetTensorType()) + .TypeConstraint("T2", DataTypeImpl::GetType()), + DateTimeTransformer); +} // namespace automl +} // namespace onnxruntime diff --git a/onnxruntime/automl_ops/cpu_automl_kernels.cc b/onnxruntime/automl_ops/cpu_automl_kernels.cc new file mode 100644 index 0000000000000..23d5e2ad72e6a --- /dev/null +++ b/onnxruntime/automl_ops/cpu_automl_kernels.cc @@ -0,0 +1,25 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "automl_ops/cpu_automl_kernels.h" +#include "core/graph/constants.h" +#include "core/framework/data_types.h" + +namespace onnxruntime { +namespace automl { + +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSAutoMLDomain, 1, DateTimeTransformer); + +void RegisterCpuAutoMLKernels(KernelRegistry& kernel_registry) { + static const BuildKernelCreateInfoFn function_table[] = { + // add more kernels here + BuildKernelCreateInfo + }; + + for (auto& function_table_entry : function_table) { + kernel_registry.Register(function_table_entry()); + } +} + +} // namespace automl +} // namespace onnxruntime diff --git a/onnxruntime/automl_ops/cpu_automl_kernels.h b/onnxruntime/automl_ops/cpu_automl_kernels.h new file mode 100644 index 0000000000000..f14a8983d5a39 --- /dev/null +++ b/onnxruntime/automl_ops/cpu_automl_kernels.h @@ -0,0 +1,13 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/framework/op_kernel.h" +#include "core/framework/kernel_registry.h" + +namespace onnxruntime { +namespace automl { +void RegisterCpuAutoMLKernels(KernelRegistry& kernel_registry); +} // namespace automl +} // namespace onnxruntime diff --git a/onnxruntime/core/automl/featurizers/src/FeaturizerPrep/Featurizer.h b/onnxruntime/core/automl/featurizers/src/FeaturizerPrep/Featurizer.h new file mode 100644 index 0000000000000..54b737b645da9 --- /dev/null +++ b/onnxruntime/core/automl/featurizers/src/FeaturizerPrep/Featurizer.h @@ -0,0 +1,163 @@ +// ---------------------------------------------------------------------- +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License +// ---------------------------------------------------------------------- +#pragma once + +#include +#include + +namespace Microsoft { +namespace Featurizer { + +///////////////////////////////////////////////////////////////////////// +/// \class Transformer +/// \brief Transforms a single "value" and output the result. +/// A value can be anything from an integer to a collection +/// of integers. +/// +template +class Transformer { +public: + // ---------------------------------------------------------------------- + // | Public Types + using return_type = ReturnT; + using arg_type = ArgT; + using transformer_type = Transformer; + + // ---------------------------------------------------------------------- + // | Public Methods + Transformer(void) = default; + virtual ~Transformer(void) = default; + + Transformer(Transformer const &) = delete; + Transformer & operator =(Transformer const &) = delete; + + Transformer(Transformer &&) = default; + Transformer & operator =(Transformer &&) = delete; + + virtual return_type transform(arg_type const &arg) const = 0; + +private: + // ---------------------------------------------------------------------- + // | Private Methods + template + void serialize(ArchiveT &, unsigned int const /*version*/); +}; + +///////////////////////////////////////////////////////////////////////// +/// \class Estimator +/// \brief Collects state over a collection of data, then produces +/// a `Transformer` that is able to operate on that collected +/// state. +/// +template +class Estimator { +public: + // ---------------------------------------------------------------------- + // | Public Types + using transformer_type = Transformer; + using TransformerUniquePtr = std::unique_ptr; + + using estimator_type = Estimator; + + using apache_arrow = unsigned long; // TODO: Temp type as we figure out what will eventually be here + + // ---------------------------------------------------------------------- + // | Public Methods + Estimator(void) = default; + virtual ~Estimator(void) = default; + + Estimator(Estimator const &) = delete; + Estimator & operator =(Estimator const &) = delete; + + Estimator(Estimator &&) = default; + Estimator & operator =(Estimator &&) = delete; + + // This method can be called repeatedly in the support of streaming scenarios + Estimator & fit(apache_arrow const &data); + + // Calls to `commit` are destructive - all previously generated state should + // be reset. `Estimator` objects that want to share state prior to calls to commit + // should implement a `copy` method. + TransformerUniquePtr commit(void); + +private: + // ---------------------------------------------------------------------- + // | Private Data + bool _committed = false; + + // ---------------------------------------------------------------------- + // | Private Methods + template + void serialize(ArchiveT &, unsigned int const /*version*/); + + virtual Estimator & fit_impl(apache_arrow const &data) = 0; + virtual TransformerUniquePtr commit_impl(void) = 0; +}; + +template +typename EstimatorT::TransformerUniquePtr fit_and_commit(typename EstimatorT::apache_arrow const &data, EstimatorConstructorArgsT &&...args); + +// ---------------------------------------------------------------------- +// ---------------------------------------------------------------------- +// ---------------------------------------------------------------------- +// | +// | Implementation +// | +// ---------------------------------------------------------------------- +// ---------------------------------------------------------------------- +// ---------------------------------------------------------------------- + +// ---------------------------------------------------------------------- +// | +// | Transformer +// | +// ---------------------------------------------------------------------- +template +template +void Transformer::serialize(ArchiveT & /*ar*/, unsigned int const /*version*/) { +} + +// ---------------------------------------------------------------------- +// | +// | Estimator +// | +// ---------------------------------------------------------------------- +template +Estimator & Estimator::fit(apache_arrow const &data) { + if(_committed) + throw std::runtime_error("This instance has already been committed"); + + return fit_impl(data); +} + +template +typename Estimator::TransformerUniquePtr Estimator::commit(void) { + if(_committed) + throw std::runtime_error("This instance has already been committed"); + + TransformerUniquePtr result(commit_impl()); + + if(!result) + throw std::runtime_error("Invalid result"); + + _committed = true; + return result; +} + +template +template +void Estimator::serialize(ArchiveT & /*ar*/, unsigned int const /*version*/) { +} + +// ---------------------------------------------------------------------- +// ---------------------------------------------------------------------- +// ---------------------------------------------------------------------- +template +typename EstimatorT::TransformerUniquePtr fit_and_commit(typename EstimatorT::apache_arrow const &data, EstimatorConstructorArgsT &&...args) { + return EstimatorT(std::forward(args)...).fit(data).commit(); +} + +} // namespace Featurizer +} // namespace Microsoft diff --git a/onnxruntime/core/automl/featurizers/src/FeaturizerPrep/Featurizers/DateTimeFeaturizer.cpp b/onnxruntime/core/automl/featurizers/src/FeaturizerPrep/Featurizers/DateTimeFeaturizer.cpp new file mode 100644 index 0000000000000..56fc238d86aee --- /dev/null +++ b/onnxruntime/core/automl/featurizers/src/FeaturizerPrep/Featurizers/DateTimeFeaturizer.cpp @@ -0,0 +1,56 @@ +// ---------------------------------------------------------------------- +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License +// ---------------------------------------------------------------------- +#include "DateTimeFeaturizer.h" + +#ifdef _MSC_VER +inline struct tm *gmtime_r(time_t const* const timer, struct tm* const result) { + return gmtime_s(result, timer) == 0 ? result : nullptr; +} + +#endif + +namespace Microsoft { +namespace Featurizer { + +namespace DateTimeFeaturizer { + + TimePoint::TimePoint(const std::chrono::system_clock::time_point& sysTime) { + // Get to a tm to get what we need. + // Eventually C++202x will have expanded chrono support that might + // have what we need, but not yet! + std::tm tmt; + time_t tt = std::chrono::system_clock::to_time_t(sysTime); + std::tm* res = gmtime_r(&tt, &tmt); + if (res) { + year = static_cast(tmt.tm_year) + 1900; + month = static_cast(tmt.tm_mon) + 1; + day = static_cast(tmt.tm_mday); + hour = static_cast(tmt.tm_hour); + minute = static_cast(tmt.tm_min); + second = static_cast(tmt.tm_sec); + dayOfWeek = static_cast(tmt.tm_wday); + dayOfYear = static_cast(tmt.tm_yday); + quarterOfYear = (month + 2) / 3; + weekOfMonth = (day - 1) / 7; + } + else + { + if (tt < 0) { + throw std::invalid_argument("Dates prior to 1970 are not supported."); + } + else { + throw std::invalid_argument("Unknown error converting input date."); + } + } + } + + Transformer::return_type Transformer::transform(arg_type const &arg) const /*override*/ { + return Microsoft::Featurizer::DateTimeFeaturizer::TimePoint(arg); + } + + +} // namespace DateTimeFeaturizer +} // namespace Featurizer +} // namespace Microsoft diff --git a/onnxruntime/core/automl/featurizers/src/FeaturizerPrep/Featurizers/DateTimeFeaturizer.h b/onnxruntime/core/automl/featurizers/src/FeaturizerPrep/Featurizers/DateTimeFeaturizer.h new file mode 100644 index 0000000000000..e1f98351db0b4 --- /dev/null +++ b/onnxruntime/core/automl/featurizers/src/FeaturizerPrep/Featurizers/DateTimeFeaturizer.h @@ -0,0 +1,101 @@ +// ---------------------------------------------------------------------- +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License +// ---------------------------------------------------------------------- +#pragma once + +#include "../Featurizer.h" +#include +#include +#include +#include + +namespace Microsoft { +namespace Featurizer { + +///////////////////////////////////////////////////////////////////////// +/// \namespace DateTimeTransformer +/// \brief A Transformer that takes a chrono::system_clock::time_point and +/// returns a struct with all the data split out. +/// +namespace DateTimeFeaturizer { + + ///////////////////////////////////////////////////////////////////////// + /// \struct TimePoint + /// \brief Struct to hold various components of DateTime information + /// + struct TimePoint { + std::int32_t year = 0; + std::uint8_t month = 0; /* 1-12 */ + std::uint8_t day = 0; /* 1-31 */ + std::uint8_t hour = 0; /* 0-23 */ + std::uint8_t minute = 0; /* 0-59 */ + std::uint8_t second = 0; /* 0-59 */ + std::uint8_t dayOfWeek = 0; /* 0-6 */ + std::uint16_t dayOfYear = 0; /* 0-365 */ + std::uint8_t quarterOfYear = 0; /* 1-4 */ + std::uint8_t weekOfMonth = 0; /* 0-4 */ + + // Need default __ctor to satisfy ORT type system + TimePoint() = default; + TimePoint(const std::chrono::system_clock::time_point& sysTime); + + TimePoint(TimePoint&&) = default; + TimePoint& operator=(TimePoint&&) = default; + + TimePoint(const TimePoint&) = delete; + TimePoint& operator=(const TimePoint&) = delete; + + bool operator==(const TimePoint& o) const { + return year == o.year && + month == o.month && + day == o.day && + hour == o.hour && + minute == o.minute && + second == o.second && + dayOfWeek == o.dayOfWeek && + dayOfYear == o.dayOfYear && + quarterOfYear == o.quarterOfYear && + weekOfMonth == o.weekOfMonth; + } + + enum { + JANUARY = 1, FEBRUARY, MARCH, APRIL, MAY, JUNE, + JULY, AUGUST, SEPTEMBER, OCTOBER, NOVEMBER, DECEMBER + }; + enum { + SUNDAY = 0, MONDAY, TUESDAY, WEDNESDAY, THURSDAY, FRIDAY, SATURDAY + }; + }; + + inline TimePoint SystemToDPTimePoint(const std::chrono::system_clock::time_point& sysTime) { + return TimePoint (sysTime); + } + + ///////////////////////////////////////////////////////////////////////// + /// \class DateTimeTransformer + /// \brief Transformer + /// + class Transformer : public Microsoft::Featurizer::Transformer { + public: + Transformer(void) = default; + ~Transformer(void) override = default; + + Transformer(Transformer const &) = delete; + Transformer & operator =(Transformer const &) = delete; + + Transformer(Transformer &&) = default; + Transformer & operator =(Transformer &&) = delete; + + return_type transform(arg_type const &arg) const override; + + private: + // ---------------------------------------------------------------------- + // | Private Methods + template + void serialize(ArchiveT &ar, unsigned int const version); + }; + +} // Namespace DateTimeFeaturizer +} // Namespace Featurizer +} // Namespace Microsoft diff --git a/onnxruntime/core/automl/featurizers/src/FeaturizerPrep/Featurizers/SampleAdd.cpp b/onnxruntime/core/automl/featurizers/src/FeaturizerPrep/Featurizers/SampleAdd.cpp new file mode 100644 index 0000000000000..b474ce3bd8a62 --- /dev/null +++ b/onnxruntime/core/automl/featurizers/src/FeaturizerPrep/Featurizers/SampleAdd.cpp @@ -0,0 +1,40 @@ +// ---------------------------------------------------------------------- +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License +// ---------------------------------------------------------------------- +#include "SampleAdd.h" + +namespace Microsoft { +namespace Featurizer { +namespace SampleAdd { + +// ---------------------------------------------------------------------- +// | +// | Transformer +// | +// ---------------------------------------------------------------------- +Transformer::Transformer(std::uint16_t delta) : + _delta(delta) { +} + +Transformer::return_type Transformer::transform(arg_type const &arg) const /*override*/ { + return _delta + arg; +} + +// ---------------------------------------------------------------------- +// | +// | Estimator +// | +// ---------------------------------------------------------------------- +Estimator & Estimator::fit_impl(apache_arrow const &data) /*override*/ { + _accumulated_delta += static_cast(data); + return *this; +} + +Estimator::TransformerUniquePtr Estimator::commit_impl(void) /*override*/ { + return std::make_unique(static_cast(_accumulated_delta)); +} + +} // namespace SampleAdd +} // namespace Featurizer +} // namespace Microsoft diff --git a/onnxruntime/core/automl/featurizers/src/FeaturizerPrep/Featurizers/SampleAdd.h b/onnxruntime/core/automl/featurizers/src/FeaturizerPrep/Featurizers/SampleAdd.h new file mode 100644 index 0000000000000..5c333fd7e498b --- /dev/null +++ b/onnxruntime/core/automl/featurizers/src/FeaturizerPrep/Featurizers/SampleAdd.h @@ -0,0 +1,118 @@ +// ---------------------------------------------------------------------- +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License +// ---------------------------------------------------------------------- +#pragma once + +#include "../Featurizer.h" + +namespace Microsoft { +namespace Featurizer { + +///////////////////////////////////////////////////////////////////////// +/// \namespace SampleAdd +/// \brief A Transformer and Estimator that add values. This is a +/// sample intended to demonstrate patterns within the +/// implementation of these types. +/// +namespace SampleAdd { + +///////////////////////////////////////////////////////////////////////// +/// \class Transformer +/// \brief Transformer that adds an integer value to a saved delta +/// and returns the result. +/// +class Transformer : public Microsoft::Featurizer::Transformer { +public: + // ---------------------------------------------------------------------- + // | Public Methods + Transformer(std::uint16_t delta=0); + ~Transformer(void) override = default; + + Transformer(Transformer const &) = delete; + Transformer & operator =(Transformer const &) = delete; + + Transformer(Transformer &&) = default; + Transformer & operator =(Transformer &&) = delete; + + return_type transform(arg_type const &arg) const override; + +private: + // ---------------------------------------------------------------------- + // | Private Data + std::uint32_t const _delta; + + // ---------------------------------------------------------------------- + // | Private Methods + template + void serialize(ArchiveT &ar, unsigned int const version); +}; + +///////////////////////////////////////////////////////////////////////// +/// \class Estimator +/// \brief Estimator that accumulates a delta value and then +/// creates a Transformer with than value when requested. +/// +class Estimator : public Microsoft::Featurizer::Estimator { +public: + // ---------------------------------------------------------------------- + // | Public Methods + Estimator(void) = default; + ~Estimator(void) override = default; + + Estimator(Estimator const &) = delete; + Estimator & operator =(Estimator const &) = delete; + + Estimator(Estimator &&) = default; + Estimator & operator =(Estimator &&) = delete; + +private: + // ---------------------------------------------------------------------- + // | Private Data + std::uint32_t _accumulated_delta = 0; + + // ---------------------------------------------------------------------- + // | Private Methods + template + void serialize(ArchiveT &ar, unsigned int const version); + + Estimator & fit_impl(apache_arrow const &data) override; + TransformerUniquePtr commit_impl(void) override; +}; + +// ---------------------------------------------------------------------- +// ---------------------------------------------------------------------- +// ---------------------------------------------------------------------- +// | +// | Implementation +// | +// ---------------------------------------------------------------------- +// ---------------------------------------------------------------------- +// ---------------------------------------------------------------------- + +// ---------------------------------------------------------------------- +// | +// | Transformer +// | +// ---------------------------------------------------------------------- +template +void Transformer::serialize(ArchiveT &ar, unsigned int const version) { + ar & boost::serialization::base_object(*this); + ar & boost::serialization::make_nvp("delta", _delta); +} + +// ---------------------------------------------------------------------- +// | +// | Estimator +// | +// ---------------------------------------------------------------------- +template +void Estimator::serialize(ArchiveT &ar, unsigned int const version) { + ar & boost::serialization::base_object(*this); + ar & boost::serialization::make_nvp("accumulated_delta", _accumulated_delta); +} + +} // namespace SampleAdd + +} // namespace Featurizer +} // namespace Microsoft diff --git a/onnxruntime/core/automl/featurizers/src/FeaturizerPrep/Featurizers/UnitTests/CMakeLists.txt b/onnxruntime/core/automl/featurizers/src/FeaturizerPrep/Featurizers/UnitTests/CMakeLists.txt new file mode 100644 index 0000000000000..acbc320062979 --- /dev/null +++ b/onnxruntime/core/automl/featurizers/src/FeaturizerPrep/Featurizers/UnitTests/CMakeLists.txt @@ -0,0 +1,48 @@ +# ---------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License +# ---------------------------------------------------------------------- +cmake_minimum_required(VERSION 3.5.0) + +project(Featurizer_UnitTests LANGUAGES CXX) + +set(CMAKE_MODULE_PATH "$ENV{DEVELOPMENT_ENVIRONMENT_CMAKE_MODULE_PATH}") + +if(NOT WIN32) + string(REPLACE ":" ";" CMAKE_MODULE_PATH "${CMAKE_MODULE_PATH}") + string(REPLACE ":" ";" _includes "$ENV{INCLUDE}") + string(REPLACE ":" ";" _libs "$ENV{LIB}") +endif() + +set(CppCommon_STATIC_CRT ON CACHE BOOL "" FORCE) +set(BoostCommon_HEADER_ONLY ON CACHE BOOL "" FORCE) + +include(CppCommon) +include(BoostCommon) + +set(CMAKE_CXX_STANDARD 14) +set(CMAKE_CXX_STANDARD_REQUIRED ON) +set(CMAKE_CXX_EXTENSIONS OFF) + +add_library(libFeaturizers STATIC + ../SampleAdd.h + ../SampleAdd.cpp + ../DateTimeFeaturizer.h + ../DateTimeFeaturizer.cpp +) + +enable_testing() + +foreach(_test_name IN ITEMS + SampleAdd_UnitTest + DateTimeFeaturizer_UnitTests +) + add_executable(${_test_name} ${_test_name}.cpp) + + target_include_directories(${_test_name} PRIVATE ${_includes}) + target_link_directories(${_test_name} PRIVATE ${_libs}) + + target_link_libraries(${_test_name} PRIVATE ${Boost_LIBRARIES} libFeaturizers) + + add_test(NAME ${_test_name} COMMAND ${_test_name} --success) +endforeach() diff --git a/onnxruntime/core/automl/featurizers/src/FeaturizerPrep/Featurizers/UnitTests/DateTimeFeaturizer_UnitTests.cpp b/onnxruntime/core/automl/featurizers/src/FeaturizerPrep/Featurizers/UnitTests/DateTimeFeaturizer_UnitTests.cpp new file mode 100644 index 0000000000000..a1c6ac2f29fe4 --- /dev/null +++ b/onnxruntime/core/automl/featurizers/src/FeaturizerPrep/Featurizers/UnitTests/DateTimeFeaturizer_UnitTests.cpp @@ -0,0 +1,125 @@ +// ---------------------------------------------------------------------- +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License +// ---------------------------------------------------------------------- + +#define CATCH_CONFIG_MAIN +#include +#include "gtest/gtest.h" + +#include "../DateTimeFeaturizer.h" + + +namespace Microsoft { +namespace Featurizer { +namespace DateTimeFeaturizer { + +using SysClock = std::chrono::system_clock; + +TEST(DateTimeFeaturizer_DateTime, Past_1976_Nov_17__12_27_04) { + const time_t date = 217081624; + SysClock::time_point stp = SysClock::from_time_t(date); + + // Constructor + TimePoint tp(stp); + ASSERT_TRUE(tp.year == 1976); + ASSERT_TRUE(tp.month == TimePoint::NOVEMBER); + ASSERT_TRUE(tp.day == 17); + ASSERT_TRUE(tp.hour == 12); + ASSERT_TRUE(tp.minute == 27); + ASSERT_TRUE(tp.second == 4); + ASSERT_TRUE(tp.dayOfWeek == TimePoint::WEDNESDAY); + ASSERT_TRUE(tp.dayOfYear == 321); + ASSERT_TRUE(tp.quarterOfYear == 4); + ASSERT_TRUE(tp.weekOfMonth == 2); + + // assignment + TimePoint tp1 = stp; + ASSERT_TRUE(tp1.year == 1976); + ASSERT_TRUE(tp1.month == TimePoint::NOVEMBER); + ASSERT_TRUE(tp1.day == 17); + + // function + TimePoint tp2 = SystemToDPTimePoint(stp); + ASSERT_TRUE(tp2.year == 1976); + ASSERT_TRUE(tp2.month == TimePoint::NOVEMBER); + ASSERT_TRUE(tp2.day == 17); +} + +TEST(DateTimeFeaturizer_Transformer , Past_1976_Nov_17__12_27_05) { + const time_t date = 217081625; + SysClock::time_point stp = SysClock::from_time_t(date); + + Transformer dt; + TimePoint tp = dt.transform(stp); + ASSERT_TRUE(tp.year == 1976); + ASSERT_TRUE(tp.month == TimePoint::NOVEMBER); + ASSERT_TRUE(tp.day == 17); + ASSERT_TRUE(tp.hour == 12); + ASSERT_TRUE(tp.minute == 27); + ASSERT_TRUE(tp.second == 5); + ASSERT_TRUE(tp.dayOfWeek == TimePoint::WEDNESDAY); + ASSERT_TRUE(tp.dayOfYear == 321); + ASSERT_TRUE(tp.quarterOfYear == 4); + ASSERT_TRUE(tp.weekOfMonth == 2); + +} + +TEST(DateTimeFeaturizer_Transformer , Future_2025_June_30) { + const time_t date = 1751241600; + SysClock::time_point stp = SysClock::from_time_t(date); + + Transformer dt; + TimePoint tp = dt.transform(stp); + ASSERT_TRUE(tp.year == 2025); + ASSERT_TRUE(tp.month == TimePoint::JUNE); + ASSERT_TRUE(tp.day == 30); + ASSERT_TRUE(tp.hour == 0); + ASSERT_TRUE(tp.minute == 0); + ASSERT_TRUE(tp.second == 0); + ASSERT_TRUE(tp.dayOfWeek == TimePoint::MONDAY); + ASSERT_TRUE(tp.dayOfYear == 180); + ASSERT_TRUE(tp.quarterOfYear == 2); + ASSERT_TRUE(tp.weekOfMonth == 4); +} + +#ifdef _MSC_VER +// others define system_clock::time_point as nanoseconds (64-bit), +// which rolls over somewhere around 2260. Still a couple hundred years! +TEST(DateTimeFeaturizer_Transformer , Far_Future__2998_March_2__14_03_02) { + const time_t date = 32445842582; + SysClock::time_point stp = SysClock::from_time_t(date); + + Transformer dt; + TimePoint tp = dt.transform(stp); + ASSERT_TRUE(tp.year == 2998); + ASSERT_TRUE(tp.month == TimePoint::MARCH); + ASSERT_TRUE(tp.day == 2); + ASSERT_TRUE(tp.hour == 14); + ASSERT_TRUE(tp.minute == 3); + ASSERT_TRUE(tp.second == 2); + ASSERT_TRUE(tp.dayOfWeek == TimePoint::FRIDAY); + ASSERT_TRUE(tp.dayOfYear == 60); + ASSERT_TRUE(tp.quarterOfYear == 1); + ASSERT_TRUE(tp.weekOfMonth == 0); +} + +#else + +// msvcrt doesn't support negative time_t, so nothing before 1970 +TEST(DateTimeFeaturizer_Transformer, Pre_Epoch__1776_July_4) { + + const time_t date = -6106060800; + SysClock::time_point stp = SysClock::from_time_t(date); + + // Constructor + Transformer dt; + TimePoint tp = dt.transform(stp); + ASSERT_TRUE(tp.year == 1776); + ASSERT_TRUE(tp.month == TimePoint::JULY); + ASSERT_TRUE(tp.day == 4); +} +#endif /* _MSC_VER */ +} // namespace DateTimeFeaturizer +} // namespace Featurizer +} // namespace Microsoft diff --git a/onnxruntime/core/automl/featurizers/src/FeaturizerPrep/Featurizers/UnitTests/SampleAdd_UnitTest.cpp b/onnxruntime/core/automl/featurizers/src/FeaturizerPrep/Featurizers/UnitTests/SampleAdd_UnitTest.cpp new file mode 100644 index 0000000000000..40e7336b8ae7e --- /dev/null +++ b/onnxruntime/core/automl/featurizers/src/FeaturizerPrep/Featurizers/UnitTests/SampleAdd_UnitTest.cpp @@ -0,0 +1,22 @@ +// ---------------------------------------------------------------------- +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License +// ---------------------------------------------------------------------- + +#define CATCH_CONFIG_MAIN +#include "gtest/gtest.h" + +#include "../SampleAdd.h" + +TEST(SampleAddTests, Transformer) { + ASSERT_TRUE(Microsoft::Featurizer::SampleAdd::Transformer(10).transform(20) == 30); + ASSERT_TRUE(Microsoft::Featurizer::SampleAdd::Transformer(20).transform(1) == 21); +} + +TEST(SampleAddTests, Estimator) { + ASSERT_TRUE(Microsoft::Featurizer::SampleAdd::Estimator().fit(10).commit()->transform(20) == 30); + ASSERT_TRUE(Microsoft::Featurizer::SampleAdd::Estimator().fit(20).commit()->transform(1) == 21); + + ASSERT_TRUE(Microsoft::Featurizer::SampleAdd::Estimator().fit(10).fit(20).commit()->transform(20) == 50); + ASSERT_TRUE(Microsoft::Featurizer::SampleAdd::Estimator().fit(10).fit(20).fit(30).commit()->transform(20) == 80); +} diff --git a/onnxruntime/core/automl/featurizers/src/FeaturizerPrep/Featurizers/UnitTests/code_coverage.yaml b/onnxruntime/core/automl/featurizers/src/FeaturizerPrep/Featurizers/UnitTests/code_coverage.yaml new file mode 100644 index 0000000000000..e3f068978a9bd --- /dev/null +++ b/onnxruntime/core/automl/featurizers/src/FeaturizerPrep/Featurizers/UnitTests/code_coverage.yaml @@ -0,0 +1,5 @@ +filter: + includes: + - Microsoft::Featurizer::* + excludes: + - std::* diff --git a/onnxruntime/core/automl/featurizers/src/FeaturizerPrep/Traits.h b/onnxruntime/core/automl/featurizers/src/FeaturizerPrep/Traits.h new file mode 100644 index 0000000000000..f097956403cd9 --- /dev/null +++ b/onnxruntime/core/automl/featurizers/src/FeaturizerPrep/Traits.h @@ -0,0 +1,217 @@ +// ---------------------------------------------------------------------- +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License +// ---------------------------------------------------------------------- + +#pragma once +#include +#include +#include +#include + +namespace Microsoft { +namespace Featurizer { +namespace Traits { + +// XXX: Define the type +template +struct Nullable {}; + +///////////////////////////////////////////////////////////////////////// +/// \namespace Traits +/// \brief We have a range of of types we are dealing with. Many types +/// have different ways to represent what a `NULL` value is +/// (float has NAN for example) as well as different ways to +/// convert the value to a string representation. By using +/// templates combined with partial template specialization +/// we can handle scenarios like these that vary based on the data type. +/// +/// Example: This allows us to do things like `Traits::IsNull()` +/// and `Traits::IsNull()` and let the trait itself deal with the +/// actual implementation and allows us as developers to not worry about that. +/// +/// This benefit is magnified because we are also using templates for our +/// transformers. When we declare that a transformer has type T = std::int8_t, +/// we can then also use `Traits::IsNull()` and the compiler will know that +/// `T` is a `std::int8_t` and call the appropate template specialization. +/// +template +struct Traits {}; + +///////////////////////////////////////////////////////////////////////// +/// \namespace Traits +/// \brief When using partial template specilization, if the compiler +/// cannot find a more specfic implementation of the template +/// it will fall back to the base template and use whatever is +/// defined there. If you have methods defined in that base template, +/// it makes it very difficult to debug what is going on. By +/// putting no implementation in the `Traits<>` template and +/// having the real base struct be `TraitsImpl<>`, if you try and +/// specify a trait that doesn't have a specilization, the compiler +/// can detect that and throw an error during compilation. +/// +/// Example: There is no template `Traits`. If you try and use it +/// the compiler will fall back to the `Traits<>` struct which has no methods +/// defined. Trying to then use `Traits` will cause a compile time error +/// letting you know something isn't correct. +/// +template +struct TraitsImpl { + using nullable_type = Nullable; + static bool IsNull(nullable_type const& value) { + return !value.is_initialized(); + } +}; + +template <> +struct Traits : public TraitsImpl { + using nullable_type = float; + static bool IsNull(nullable_type const& value) { + return isnan(value); + } + + // static std::string ToString(nullable_type const& value) { + // return std::to_string(value); + // } +}; + +template <> +struct Traits : public TraitsImpl { + using nullable_type = double; + static bool IsNull(nullable_type const& value) { + return isnan(value); + } + + // static std::string ToString(nullable_type const& value) { + // return std::to_string(value); + // } +}; + +template <> +struct Traits : public TraitsImpl { + // static std::string ToString(std::int8_t const& value) { + // return std::to_string(value); + // } +}; + +template <> +struct Traits : public TraitsImpl { + // static std::string ToString(std::int16_t const& value) { + // return std::to_string(value); + // } +}; + +template <> +struct Traits : public TraitsImpl { + // static std::string ToString(std::int32_t const& value) { + // return std::to_string(value); + // } +}; + +template <> +struct Traits : public TraitsImpl { + // static std::string ToString(std::int64_t const& value) { + // return std::to_string(value); + // } +}; + +template <> +struct Traits : public TraitsImpl { + // static std::string ToString(std::uint8_t const& value) { + // return std::to_string(value); + // } +}; + +template <> +struct Traits : public TraitsImpl { + using nullable_type = Nullable; + // static std::string ToString(std::uint16_t const& value) { + // return std::to_string(value); + // } +}; + +template <> +struct Traits : public TraitsImpl { + // static std::string ToString(std::uint32_t const& value) { + // return std::to_string(value); + // } +}; + +template <> +struct Traits : public TraitsImpl { + // static std::string ToString(std::uint64_t const& value) { + // return std::to_string(value); + // } +}; + +template <> +struct Traits : public TraitsImpl { + // static std::string ToString(std::string const& value) { + // value; + // } +}; + +template +struct Traits> : public TraitsImpl> { + // static std::string ToString(std::array const& value) { + // // Decide what to return here + // throw std::logic_error("Function not yet implemented"); + // } +}; + +template <> +struct Traits : public TraitsImpl { + // static std::string ToString(bool const& value) { + // // Decide what to return here + // throw std::logic_error("Function not yet implemented"); + // } +}; + +template +struct Traits> : public TraitsImpl> { + // static std::string ToString(std::map const& value) { + // // Decide what to return here + // throw std::logic_error("Function not yet implemented"); + // } +}; + +template +struct Traits> : public TraitsImpl> { + // static std::string ToString(std::vector const& value) { + // // Decide what to return here + // throw std::logic_error("Function not yet implemented"); + // } +}; + +template +struct Traits> : public TraitsImpl> { + // static std::string ToString(std::function const& value) { + // // Decide what to return here + // throw std::logic_error("Function not yet implemented"); + // } +}; + +template +struct Traits> : public TraitsImpl> { + using nullable_type = Nullable; + + // static std::string ToString(nullable_type const& value) { + // if (value) { + // return Traits::ToString(value.get()); + // } + + // return "NULL"; + // } +}; + +template +struct Traits> : public TraitsImpl> { + // static std::string ToString(std::tuple const& value) { + // // Decide what to return here + // throw std::logic_error("Function not yet implemented"); + // } +}; + +} // namespace Traits +} // namespace Featurizer +} // namespace Microsoft diff --git a/onnxruntime/core/automl/featurizers/src/FeaturizerPrep/UnitTests/CMakeLists.txt b/onnxruntime/core/automl/featurizers/src/FeaturizerPrep/UnitTests/CMakeLists.txt new file mode 100644 index 0000000000000..024c76f3443a7 --- /dev/null +++ b/onnxruntime/core/automl/featurizers/src/FeaturizerPrep/UnitTests/CMakeLists.txt @@ -0,0 +1,41 @@ +# ---------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License +# ---------------------------------------------------------------------- +cmake_minimum_required(VERSION 3.5.0) + +project(Featurizer_UnitTests LANGUAGES CXX) + +set(CMAKE_MODULE_PATH "$ENV{DEVELOPMENT_ENVIRONMENT_CMAKE_MODULE_PATH}") + +if(NOT WIN32) + string(REPLACE ":" ";" CMAKE_MODULE_PATH "${CMAKE_MODULE_PATH}") + string(REPLACE ":" ";" _includes "$ENV{INCLUDE}") + string(REPLACE ":" ";" _libs "$ENV{LIB}") +endif() + +set(CppCommon_STATIC_CRT ON CACHE BOOL "" FORCE) +set(BoostCommon_HEADER_ONLY ON CACHE BOOL "" FORCE) + +include(CppCommon) +include(BoostCommon) + +set(CMAKE_CXX_STANDARD 14) +set(CMAKE_CXX_STANDARD_REQUIRED ON) +set(CMAKE_CXX_EXTENSIONS OFF) + +enable_testing() + +foreach(_test_name IN ITEMS + Featurizer_UnitTest + Traits_UnitTests +) + add_executable(${_test_name} ${_test_name}.cpp) + + target_include_directories(${_test_name} PRIVATE ${_includes}) + target_link_directories(${_test_name} PRIVATE ${_libs}) + + target_link_libraries(${_test_name} PRIVATE ${Boost_LIBRARIES}) + + add_test(NAME ${_test_name} COMMAND ${_test_name} --success) +endforeach() diff --git a/onnxruntime/core/automl/featurizers/src/FeaturizerPrep/UnitTests/Featurizer_UnitTest.cpp b/onnxruntime/core/automl/featurizers/src/FeaturizerPrep/UnitTests/Featurizer_UnitTest.cpp new file mode 100644 index 0000000000000..caed233f48419 --- /dev/null +++ b/onnxruntime/core/automl/featurizers/src/FeaturizerPrep/UnitTests/Featurizer_UnitTest.cpp @@ -0,0 +1,119 @@ +// ---------------------------------------------------------------------- +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License +// ---------------------------------------------------------------------- + +#define CATCH_CONFIG_MAIN +#include "gtest/gtest.h" +#include "../Featurizer.h" + +class MyTransformer : public Microsoft::Featurizer::Transformer { +public: + // ---------------------------------------------------------------------- + // | Public Methods + MyTransformer(bool true_on_odd=false) : + _true_on_odd(true_on_odd) { + } + + ~MyTransformer(void) override = default; + + MyTransformer(MyTransformer const &) = delete; + MyTransformer & operator =(MyTransformer const &) = delete; + + MyTransformer(MyTransformer &&) = default; + MyTransformer & operator =(MyTransformer &&) = delete; + + return_type transform(arg_type const &arg) const override { + bool const is_odd(arg & 1); + + return _true_on_odd ? is_odd : !is_odd; + } + +private: + // ---------------------------------------------------------------------- + // | Private Data + bool const _true_on_odd; + + // ---------------------------------------------------------------------- + // | Private Methods + template + void serialize(ArchiveT &ar, unsigned int const /*version*/) { + ar & boost::serialization::base_object(*this); + ar & boost::serialization::make_nvp("true_on_odd", const_cast(_true_on_odd)); + } +}; + +class MyEstimator : public Microsoft::Featurizer::Estimator { +public: + // ---------------------------------------------------------------------- + // | Public Methods + MyEstimator(bool return_invalid_transformer=false) : + _return_invalid_transformer(return_invalid_transformer) { + } + + ~MyEstimator(void) override = default; + + MyEstimator(MyEstimator const &) = delete; + MyEstimator & operator =(MyEstimator const &) = delete; + + MyEstimator(MyEstimator &&) = default; + MyEstimator & operator =(MyEstimator &&) = delete; + +private: + // ---------------------------------------------------------------------- + // | Private Data + bool const _return_invalid_transformer; + bool _true_on_odd_state; + + // ---------------------------------------------------------------------- + // | Private Methods + MyEstimator & fit_impl(apache_arrow const &data) override { + _true_on_odd_state = static_cast(data); + return *this; + } + + TransformerUniquePtr commit_impl(void) override { + if(_return_invalid_transformer) + return TransformerUniquePtr(); + + return std::make_unique(_true_on_odd_state); + } + + template + void serialize(ArchiveT &ar, unsigned int const /*version*/) { + ar & boost::serialization::base_object(*this); + ar & boost::serialization::make_nvp("return_invalid_transformer", const_cast(_return_invalid_transformer)); + ar & boost::serialization::make_nvp("true_on_odd_state", const_cast(_true_on_odd_state)); + } +}; + +TEST(FeaturizerTests, TransformerFunctionality) { + ASSERT_TRUE(MyTransformer(true).transform(1) == true); + ASSERT_TRUE(MyTransformer(false).transform(1) == false); + ASSERT_TRUE(MyTransformer(true).transform(2) == false); + ASSERT_TRUE(MyTransformer(false).transform(2) == true); +} + +TEST(FeaturizerTests, EstimatorFunctionality) { + ASSERT_TRUE(MyEstimator().fit(1).commit()->transform(1) == true); + ASSERT_TRUE(MyEstimator().fit(0).commit()->transform(1) == false); + ASSERT_TRUE(MyEstimator().fit(1).commit()->transform(2) == false); + ASSERT_TRUE(MyEstimator().fit(0).commit()->transform(2) == true); +} + +TEST(FeaturizerTests, EstimatorErrors) { + MyEstimator e; + + ASSERT_NE(e.commit(), nullptr); + //CHECK_THROWS_WITH(e.fit(1), Catch::Contains("has already been committed")); + //CHECK_THROWS_WITH(e.commit(), Catch::Contains("has already been committed")); + + //CHECK_THROWS_WITH(MyEstimator(true).commit(), Catch::Matches("Invalid result")); +} + +TEST(FeaturizerTests, EstimatorFitAndCommit) { + ASSERT_TRUE(Microsoft::Featurizer::fit_and_commit(1, false)->transform(1) == true); + ASSERT_TRUE(Microsoft::Featurizer::fit_and_commit(0, false)->transform(1) == false); + ASSERT_TRUE(Microsoft::Featurizer::fit_and_commit(1, false)->transform(2) == false); + ASSERT_TRUE(Microsoft::Featurizer::fit_and_commit(0, false)->transform(2) == true); +} diff --git a/onnxruntime/core/automl/featurizers/src/FeaturizerPrep/UnitTests/Traits_UnitTests.cpp b/onnxruntime/core/automl/featurizers/src/FeaturizerPrep/UnitTests/Traits_UnitTests.cpp new file mode 100644 index 0000000000000..66589a5c9decc --- /dev/null +++ b/onnxruntime/core/automl/featurizers/src/FeaturizerPrep/UnitTests/Traits_UnitTests.cpp @@ -0,0 +1,40 @@ +// ---------------------------------------------------------------------- +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License +// ---------------------------------------------------------------------- +#define CATCH_CONFIG_MAIN +#include +#include "gtest/gtest.h" + +#include "../Traits.h" + +using namespace Microsoft::Featurizer::Traits; + +// Floating point values +static_assert(std::is_same::nullable_type, float>::value, "Incorrect nullable type for float"); +static_assert(std::is_same::nullable_type, double>::value, "Incorrect nullable type for double"); + +// Int values +static_assert(std::is_same::nullable_type, Nullable>::value, "Incorrect nullable type for std::int8_t"); +static_assert(std::is_same::nullable_type, Nullable>::value, "Incorrect nullable type for std::int16_t"); +static_assert(std::is_same::nullable_type, Nullable>::value, "Incorrect nullable type for std::int32_t"); +static_assert(std::is_same::nullable_type, Nullable>::value, "Incorrect nullable type for std::int64_t"); +static_assert(std::is_same::nullable_type, Nullable>::value, "Incorrect nullable type for std::uint8_t"); +static_assert(std::is_same::nullable_type, Nullable>::value, "Incorrect nullable type for std::uint16_t"); +static_assert(std::is_same::nullable_type, Nullable>::value, "Incorrect nullable type for std::uint32_t"); +static_assert(std::is_same::nullable_type, Nullable>::value, "Incorrect nullable type for std::uint64_t"); + +// Others +static_assert(std::is_same::nullable_type, Nullable>::value, "Incorrect nullable type for std::string"); +static_assert(std::is_same>::nullable_type, Nullable>>::value, "Incorrect nullable type for std::array"); +static_assert(std::is_same::nullable_type, Nullable>::value, "Incorrect nullable type for std::string"); +static_assert(std::is_same>::nullable_type, Nullable>>::value, "Incorrect nullable type for std::string"); +static_assert(std::is_same>::nullable_type, Nullable>>::value, "Incorrect nullable type for std::string"); +static_assert(std::is_same>::nullable_type, Nullable>>::value, "Incorrect nullable type for std::string"); +static_assert(std::is_same>::nullable_type, Nullable>::value, "Incorrect nullable type for std::string"); +static_assert(std::is_same>::nullable_type, Nullable>>::value, "Incorrect nullable type for std::string"); + +// Dummy test so it will compile. Replace this with actual tests. +TEST(TraitsTests, Dummy) { + ASSERT_TRUE(true); +} diff --git a/onnxruntime/core/automl/featurizers/src/FeaturizerPrep/UnitTests/test_main.cpp b/onnxruntime/core/automl/featurizers/src/FeaturizerPrep/UnitTests/test_main.cpp new file mode 100644 index 0000000000000..b6a004002b83c --- /dev/null +++ b/onnxruntime/core/automl/featurizers/src/FeaturizerPrep/UnitTests/test_main.cpp @@ -0,0 +1,18 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "gtest/gtest.h" + +GTEST_API_ int main(int argc, char** argv) { + int status = 0; + + testing::InitGoogleTest(&argc, argv); + try { + status = RUN_ALL_TESTS(); + } catch (const std::exception& ex) { + std::cerr << ex.what(); + status = -1; + } + + return status; +} diff --git a/onnxruntime/core/framework/data_types.cc b/onnxruntime/core/framework/data_types.cc index b41a59518cfaf..410a7b5628c93 100644 --- a/onnxruntime/core/framework/data_types.cc +++ b/onnxruntime/core/framework/data_types.cc @@ -6,6 +6,10 @@ #include "core/framework/sparse_tensor.h" #include "core/graph/onnx_protobuf.h" +#ifdef MICROSOFT_AUTOML +#include "automl_ops/automl_types.h" +#endif + #ifdef __GNUC__ #pragma GCC diagnostic push #pragma GCC diagnostic ignored "-Wignored-qualifiers" @@ -285,6 +289,9 @@ class DataTypeRegistry { DataTypeRegistry() { RegisterAllProtos([this](MLDataType mltype) { RegisterDataType(mltype); }); +#ifdef MICROSOFT_AUTOML + automl::RegisterAutoMLTypes([this](MLDataType mltype) { RegisterDataType(mltype); }); +#endif } ~DataTypeRegistry() = default; diff --git a/onnxruntime/core/graph/automl_ops/automl_defs.cc b/onnxruntime/core/graph/automl_ops/automl_defs.cc new file mode 100644 index 0000000000000..dc4dd653f37c0 --- /dev/null +++ b/onnxruntime/core/graph/automl_ops/automl_defs.cc @@ -0,0 +1,46 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/graph/constants.h" +#include "core/graph/automl_ops/automl_defs.h" +#include "core/graph/op.h" +#include "onnx/defs/schema.h" +#include "onnx/defs/shape_inference.h" + +namespace onnxruntime { +namespace automl { +using ONNX_NAMESPACE::AttributeProto; +using ONNX_NAMESPACE::OpSchema; +using ONNX_NAMESPACE::OPTIONAL; + +void RegisterAutoMLSchemas() { + + static const char* DateTimeTransformer_ver1_doc = R"DOC( + DateTimeTransformer accepts a single scalar int64 tensor, constructs + an instance of std::chrono::system_clock::time_point and passes it as an argument + to Microsoft::DateTimeFeaturizer which is a part of a shared library. + It returns an instance of TimePoint class. + )DOC"; + + MS_AUTOML_OPERATOR_SCHEMA(DateTimeTransformer) + .SinceVersion(1) + .SetDomain(kMSAutoMLDomain) + .SetDoc(DateTimeTransformer_ver1_doc) + .Input(0, "X", + "The input represents a number of seconds passed since the epoch, suitable to properly construct" + "an instance of std::chrono::system_clock::time_point", + "T1") + .Output(0, "Y", "The output which is a Microsoft::DateTimeFeaturizer::TimePoint structure", "T2") + .TypeConstraint( + "T1", + {"tensor(int64)"}, + "Constrain input type to int64 scalar tensor.") + .TypeConstraint( + "T2", + {"opaque(com.microsoft.automl,DateTimeFeaturizer_TimePoint)"}, + "Constrain output type to an AutoML specific Microsoft::Featurizers::TimePoint type" + "currently not part of ONNX standard. When it becomes a part of the standard we will adjust this" + "kernel definition and move it to ONNX repo"); +} +} // namespace automl +} // namespace onnxruntime diff --git a/onnxruntime/core/graph/automl_ops/automl_defs.h b/onnxruntime/core/graph/automl_ops/automl_defs.h new file mode 100644 index 0000000000000..b1a37366c396d --- /dev/null +++ b/onnxruntime/core/graph/automl_ops/automl_defs.h @@ -0,0 +1,30 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/graph/onnx_protobuf.h" + +namespace onnxruntime { +namespace automl { +#define MS_AUTOML_OPERATOR_SCHEMA(name) \ + MS_AUTOML_OPERATOR_SCHEMA_UNIQ_HELPER(__COUNTER__, name) +#define MS_AUTOML_OPERATOR_SCHEMA_UNIQ_HELPER(Counter, name) \ + MS_AUTOML_OPERATOR_SCHEMA_UNIQ(Counter, name) +#define MS_AUTOML_OPERATOR_SCHEMA_UNIQ(Counter, name) \ + static ONNX_NAMESPACE::OpSchemaRegistry::OpSchemaRegisterOnce( \ + op_schema_register_once##name##Counter) ONNX_UNUSED = \ + ONNX_NAMESPACE::OpSchema(#name, __FILE__, __LINE__) + +#define MS_AUTOML_OPERATOR_SCHEMA_ELSEWHERE(name, schema_func) \ + MS_AUTOML_OPERATOR_SCHEMA_UNIQ_HELPER_ELSEWHERE(__COUNTER__, name, schema_func) +#define MS_AUTOML_OPERATOR_SCHEMA_UNIQ_HELPER_ELSEWHERE(Counter, name, schema_func) \ + MS_AUTOML_OPERATOR_SCHEMA_UNIQ_ELSEWHERE(Counter, name, schema_func) +#define MS_AUTOML_OPERATOR_SCHEMA_UNIQ_ELSEWHERE(Counter, name, schema_func) \ + static ONNX_NAMESPACE::OpSchemaRegistry::OpSchemaRegisterOnce( \ + op_schema_register_once##name##Counter) ONNX_UNUSED = \ + schema_func(ONNX_NAMESPACE::OpSchema(#name, __FILE__, __LINE__)) + +void RegisterAutoMLSchemas(); +} // namespace automl +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc index bb95212533216..246e8af3d93e9 100644 --- a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc +++ b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc @@ -9,6 +9,10 @@ #include "contrib_ops/cpu_contrib_kernels.h" #endif +#ifdef MICROSOFT_AUTOML +#include "automl_ops/cpu_automl_kernels.h" +#endif + #include "core/framework/compute_capability.h" namespace onnxruntime { @@ -696,6 +700,9 @@ static void RegisterCPUKernels(KernelRegistry& kernel_registry) { #ifndef DISABLE_CONTRIB_OPS ::onnxruntime::contrib::RegisterCpuContribKernels(kernel_registry); #endif +#ifdef MICROSOFT_AUTOML + ::onnxruntime::automl::RegisterCpuAutoMLKernels(kernel_registry); +#endif } std::shared_ptr GetCpuKernelRegistry() { diff --git a/onnxruntime/core/session/environment.cc b/onnxruntime/core/session/environment.cc index ef253bbb61ad2..d1f9041c9253f 100644 --- a/onnxruntime/core/session/environment.cc +++ b/onnxruntime/core/session/environment.cc @@ -10,6 +10,9 @@ #ifndef DISABLE_CONTRIB_OPS #include "core/graph/contrib_ops/contrib_defs.h" #endif +#ifdef MICROSOFT_AUTOML +#include "core/graph/automl_ops/automl_defs.h" +#endif namespace onnxruntime { using namespace ::onnxruntime::common; @@ -33,10 +36,14 @@ Status Environment::Initialize() { std::call_once(schemaRegistrationOnceFlag, []() { ONNX_NAMESPACE::OpSchemaRegistry::DomainToVersionRange::Instance().AddDomainToVersion(onnxruntime::kMSDomain, 1, 1); ONNX_NAMESPACE::OpSchemaRegistry::DomainToVersionRange::Instance().AddDomainToVersion(onnxruntime::kMSNchwcDomain, 1, 1); + ONNX_NAMESPACE::OpSchemaRegistry::DomainToVersionRange::Instance().AddDomainToVersion(onnxruntime::kMSAutoMLDomain, 1, 1); // Register contributed schemas. // The corresponding kernels are registered inside the appropriate execution provider. #ifndef DISABLE_CONTRIB_OPS contrib::RegisterContribSchemas(); +#endif +#ifdef MICROSOFT_AUTOML + automl::RegisterAutoMLSchemas(); #endif RegisterOnnxOperatorSetSchema(); RegisterOnnxMLOperatorSetSchema(); diff --git a/onnxruntime/test/automl_ops/datetimetransformer_test.cc b/onnxruntime/test/automl_ops/datetimetransformer_test.cc new file mode 100644 index 0000000000000..d0d82a1fbc39b --- /dev/null +++ b/onnxruntime/test/automl_ops/datetimetransformer_test.cc @@ -0,0 +1,70 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "gtest/gtest.h" +#include "test/providers/provider_test_utils.h" + +#include "core/automl/featurizers/src/FeaturizerPrep/Featurizers/DateTimeFeaturizer.h" + +namespace dft = Microsoft::Featurizer::DateTimeFeaturizer; + +using SysClock = std::chrono::system_clock; + +namespace onnxruntime { +namespace test { + +TEST(DateTimeFeaturizer_DateTime, Past_1976_Nov_17__12_27_04) { + + const time_t date = 217081624; + OpTester test("DateTimeTransformer", 1, onnxruntime::kMSAutoMLDomain); + + // We are adding a scalar Tensor in this instance + test.AddInput("X", {1}, {date}); + + SysClock::time_point stp = SysClock::from_time_t(date); + dft::TimePoint tp(stp); + ASSERT_TRUE(tp.year == 1976); + ASSERT_TRUE(tp.month == dft::TimePoint::NOVEMBER); + ASSERT_TRUE(tp.day == 17); + ASSERT_TRUE(tp.hour == 12); + ASSERT_TRUE(tp.minute == 27); + ASSERT_TRUE(tp.second == 4); + ASSERT_TRUE(tp.dayOfWeek == dft::TimePoint::WEDNESDAY); + ASSERT_TRUE(tp.dayOfYear == 321); + ASSERT_TRUE(tp.quarterOfYear == 4); + ASSERT_TRUE(tp.weekOfMonth == 2); + + // Expected output. + test.AddOutput("Y", std::move(tp)); + test.Run(OpTester::ExpectResult::kExpectSuccess); +} + +TEST(DateTimeFeaturizer_Transformer, Past_1976_Nov_17__12_27_05) { + const time_t date = 32445842582; + + OpTester test("DateTimeTransformer", 1, onnxruntime::kMSAutoMLDomain); + // We are adding a scalar Tensor in this instance + test.AddInput("X", {1}, {date}); + + SysClock::time_point stp = SysClock::from_time_t(date); + + dft::Transformer dt; + dft::TimePoint tp = dt.transform(stp); + ASSERT_TRUE(tp.year == 2998); + ASSERT_TRUE(tp.month == dft::TimePoint::MARCH); + ASSERT_TRUE(tp.day == 2); + ASSERT_TRUE(tp.hour == 14); + ASSERT_TRUE(tp.minute == 3); + ASSERT_TRUE(tp.second == 2); + ASSERT_TRUE(tp.dayOfWeek == dft::TimePoint::FRIDAY); + ASSERT_TRUE(tp.dayOfYear == 60); + ASSERT_TRUE(tp.quarterOfYear == 1); + ASSERT_TRUE(tp.weekOfMonth == 0); + + // Expected output. + test.AddOutput("Y", std::move(tp)); + test.Run(OpTester::ExpectResult::kExpectSuccess); +} + +} // namespace test +} // namespace onnxruntime diff --git a/onnxruntime/test/providers/provider_test_utils.cc b/onnxruntime/test/providers/provider_test_utils.cc index 4abbc94b827a9..8a5439e29ef70 100644 --- a/onnxruntime/test/providers/provider_test_utils.cc +++ b/onnxruntime/test/providers/provider_test_utils.cc @@ -14,6 +14,11 @@ #include "core/session/inference_session.h" #include "test/util/include/default_providers.h" +#ifdef MICROSOFT_AUTOML +#include "automl_ops/automl_featurizers.h" +namespace dtf = Microsoft::Featurizer::DateTimeFeaturizer; +#endif + using namespace ::onnxruntime::logging; namespace onnxruntime { @@ -129,6 +134,30 @@ void Check(const OpTester::Data& expected_data, const Tensor& output_ } } +template <> +void Check(const OpTester::Data& expected_data, const Tensor& output_tensor, const std::string& provider_type) { + auto& expected_tensor = expected_data.data_.Get(); + auto* expected = expected_tensor.template Data(); + auto* output = output_tensor.template Data(); + auto size = output_tensor.Shape().Size(); + + std::vector f_expected(size); + std::vector f_output(size); + BFloat16ToFloat(expected, f_expected.data(), static_cast(size)); + BFloat16ToFloat(output, f_output.data(), static_cast(size)); + + /// XXX: May need to adjust threshold as BFloat is coarse + float threshold = 0.001f; + for (int i = 0; i < size; ++i) { + if (std::isinf(f_expected[i])) // Test infinity for equality + EXPECT_EQ(f_expected[i], f_output[i]); + else { + // the default for existing tests + EXPECT_NEAR(f_expected[i], f_output[i], threshold) << "provider_type: " << provider_type; + } + } +} + template void CheckDispatch(MLDataType type, const OpTester::Data& expected_data, const Tensor& output_tensor, const std::string& provider_type) { if (type == DataTypeImpl::GetType()) @@ -180,8 +209,13 @@ void CheckDispatch(MLDataType type, const OpTester::Data& expected_data, OrtValu } void Check(const OpTester::Data& expected_data, OrtValue& ort_value, const std::string& provider_type) { +#ifdef MICROSOFT_AUTOML + CheckDispatch(expected_data.data_.Type(), expected_data, ort_value, + provider_type); +#else CheckDispatch(expected_data.data_.Type(), expected_data, ort_value, provider_type); +#endif } void DebugTrap() { diff --git a/onnxruntime/test/providers/provider_test_utils.h b/onnxruntime/test/providers/provider_test_utils.h index 5f93eabbcc714..f7343b7cb924e 100644 --- a/onnxruntime/test/providers/provider_test_utils.h +++ b/onnxruntime/test/providers/provider_test_utils.h @@ -176,6 +176,30 @@ class OpTester { AddData(input_data_, name, dims, values.data(), values.size(), is_initializer); } + // Add other registered types, possibly experimental + template + void AddInput(const char* name, const T& val) { + auto mltype = DataTypeImpl::GetType(); + ORT_ENFORCE(mltype != nullptr, "T must be a registered cpp type"); + auto ptr = std::make_unique(val); + OrtValue value; + value.Init(ptr.get(), mltype, mltype->GetDeleteFunc()); + ptr.release(); + input_data_.push_back({{name, mltype->GetTypeProto()}, value, optional(), optional()}); + } + + template + void AddInput(const char* name, T&& val) { + auto mltype = DataTypeImpl::GetType(); + ORT_ENFORCE(mltype != nullptr, "T must be a registered cpp type"); + auto ptr = std::make_unique(std::move(val)); + OrtValue value; + value.Init(ptr.get(), mltype, mltype->GetDeleteFunc()); + ptr.release(); + input_data_.push_back({{name, mltype->GetTypeProto()}, value, optional(), optional()}); + } + + template void AddInput(const char* name, const std::map& val) { std::unique_ptr> ptr = std::make_unique>(val); @@ -208,6 +232,29 @@ class OpTester { output_data_.push_back({{name, &s_type_proto}, {}, optional(), optional()}); } + // Add other registered types, possibly experimental + template + void AddOutput(const char* name, const T& val) { + auto mltype = DataTypeImpl::GetType(); + ORT_ENFORCE(mltype != nullptr, "T must be a registered cpp type"); + auto ptr = std::make_unique(val); + OrtValue value; + value.Init(ptr.get(), mltype, mltype->GetDeleteFunc()); + ptr.release(); + output_data_.push_back({{name, mltype->GetTypeProto()}, value, optional(), optional()}); + } + + template + void AddOutput(const char* name, T&& val) { + auto mltype = DataTypeImpl::GetType(); + ORT_ENFORCE(mltype != nullptr, "T must be a registered cpp type"); + auto ptr = std::make_unique(std::move(val)); + OrtValue value; + value.Init(ptr.get(), mltype, mltype->GetDeleteFunc()); + ptr.release(); + output_data_.push_back({{name, mltype->GetTypeProto()}, value, optional(), optional()}); + } + // Add non tensor output template void AddOutput(const char* name, const std::vector>& val) { diff --git a/tools/ci_build/build.py b/tools/ci_build/build.py index bbaa891ef282b..2707f9d863823 100755 --- a/tools/ci_build/build.py +++ b/tools/ci_build/build.py @@ -127,6 +127,7 @@ def parse_arguments(): parser.add_argument("--use_openblas", action='store_true', help="Build with OpenBLAS.") parser.add_argument("--use_mkldnn", action='store_true', help="Build with MKLDNN.") parser.add_argument("--use_mklml", action='store_true', help="Build with MKLML.") + parser.add_argument("--use_automl", action='store_true', help="Build with AutoML support.") parser.add_argument("--use_ngraph", action='store_true', help="Build with nGraph.") parser.add_argument("--use_openvino", nargs="?", const="CPU_FP32", choices=["CPU_FP32","GPU_FP32","GPU_FP16","VAD-M_FP16","MYRIAD_FP16"], help="Build with OpenVINO for specific hardware.") @@ -323,6 +324,7 @@ def generate_build_tree(cmake_path, source_dir, build_dir, cuda_home, cudnn_home "-Donnxruntime_USE_CUDA=" + ("ON" if args.use_cuda else "OFF"), "-Donnxruntime_USE_NSYNC=" + ("OFF" if is_windows() or not args.use_nsync else "ON"), "-Donnxruntime_CUDNN_HOME=" + (cudnn_home if args.use_cuda else ""), + "-Donnxruntime_USE_AUTOML=" + ("ON" if args.use_automl else "OFF"), "-Donnxruntime_CUDA_HOME=" + (cuda_home if args.use_cuda else ""), "-Donnxruntime_USE_JEMALLOC=" + ("ON" if args.use_jemalloc else "OFF"), "-Donnxruntime_ENABLE_PYTHON=" + ("ON" if args.enable_pybind else "OFF"),