From 16cb9833395a4904ab6d5e37ea538caaaf70e3dd Mon Sep 17 00:00:00 2001
From: Ashwini Khade
Date: Thu, 5 Sep 2019 11:54:21 -0700
Subject: [PATCH] optimize quantize (#1762)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* Add more type support for OneHot op (#1565)
* parallel build
* update quantizelinear to process int8 input (#1576) (see the NumPy sketch below)
* Remove unneeded C APIs + some refactoring. (#1555)
* Mention OrtCreateSessionFromArray in C API doc
* c api changes after review (1)
* updates...
* fixes
* Reorder include
* A few performance improvements coming out of ssd_mobilenet and ssd_resnet34 analysis (#1578)
* A few performance improvements:
  - Make the iteration in NonZero more efficient by using a raw pointer and simplifying the increment logic - add another unit test to check the new logic works with a 3-dimensional tensor - gains about 2% for ssd_mobilenet
  - Avoid floating point operations on each iteration in Concat - about 0.5% for ssd_mobilenet and ssd_resnet34
  - Put the common case first in ExecutionFrame::AllocateAsPerAllocationPlan to avoid an unnecessary call to IsSparseTensor - about 0.05% for ssd_mobilenet
  - Minor tweak to put some ctors in the TensorShape header so they can be inlined more easily
* Fix race condition issue in RNN/LSTM/GRU (#1544)
  Description: filter_desc and rnn_desc could be modified during Compute(), which may run on multiple threads, causing a race condition.
  Fix: create temporary cudnn descriptors and cache cudnn_dropout_desc_, which does not change.
* Remove memory copy between TensorRT and CUDA (#1561)
* remove memory copy between CUDA and TRT
* add info to RegisterExecutionProvider input
* use new IDeviceAllocator for trt allocator
* remove SetDefaultInputsMemoryType from TRT EP
* remove onnx-tensorrt 5.0
* add submodule onnx-tensorrt branch 5.1
* remove redundancy
* Update transformer_memcpy.cc
* Update tensorrt_execution_provider.cc
* switch to TensorRT 5.1.5.0
* update python binding
* disable failed test case on TensorRT
* Update activation_op_test.cc
* upgrade to TensorRT container 19.06
* update according to feedback
* add comments
* remove tensorrt allocator and use cuda(gpu) allocator
* update onnx-tensorrt submodule
* change ci build cuda directory name
* Optimize Fence checking performance (#1593)
* For the majority of nodes we do not need to do a fence check; we only need to do a FenceCheck for CPU<->GPU mem sync nodes, but we pay the fence check cost for every single node and every single input and output. This change minimizes the fence check to only do it when necessary.
* Added license files in the base image (#1595)
* Update Dockerfile.openvino
* Update Dockerfile.cuda
* Update Dockerfile.cuda
* Update Dockerfile.openvino
* Update Dockerfile.cuda
* added ThirdParty notice file to base image.
* corrected license file name
* Implement new LabelEncoder in opset 2 in ML domain (#1393)
* Implement new LabelEncoder in opset 2 in ML domain
* Fix compilation error
* Fix tests
* Include ONNX's fix
* Formatting and addressing a comment
* Address a minor comment
* add int64 support for less op. (#1604)
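The `update quantizelinear to process int8 input (#1576)` entry above extends QuantizeLinear handling to int8 tensors. As a rough illustration of the arithmetic involved, here is a minimal NumPy sketch of the ONNX QuantizeLinear formula for an int8 target; the function name and sample values are made up for illustration, and this is not the kernel code changed by that PR:

```python
import numpy as np

def quantize_linear_int8(x, scale, zero_point):
    # ONNX QuantizeLinear for an int8 output:
    # y = saturate(round(x / scale) + zero_point), saturated to [-128, 127].
    q = np.round(x / scale) + zero_point
    return np.clip(q, -128, 127).astype(np.int8)

x = np.array([-1.0, -0.5, 0.0, 0.25, 1.0], dtype=np.float32)
print(quantize_linear_int8(x, scale=np.float32(0.0078125), zero_point=0))
```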
* put all gemmlowp common code in one place (#1590)
* put all gemmlowp common code in one place
* fix gpu build failures
* minor update
* Update nGraph to v0.22.1 (#1582)
* Update nGraph to 0.21 and adjust the EP
* Share the graph initializers between custom ops
* Update nGraph to 0.22 and exclude Gather entirely
* Enable building on Windows with nGraph v0.21.1-rc.0
* Disable the unsigned input Shrink op tests for nGraph until the next update
* Line-shortening code refactor
* Fix for the master branch merge artifact
* MKLDNN patches adjustment for Windows
* Exclude MatMulInteger for non-const zero points
* Exclude ConvInteger for non-const zero points
* Enable full Cast op support
* Use the v0.22.1 tag
* Skip ConvTranspose_InvalidKernelShape test for ngraph provider
* Create sub-graph ModelProto from fused_node
* Include io_win32.h only when building on Windows (#1587)
* Include io_win32.h only when building on Windows
* looks like include order matters
* Fix for CPU random ops seed narrowing conversion. (#1594)
* Fix perf test executable. (#1598)
* Mention OrtCreateSessionFromArray in C API doc
* Fix perf test executable due to removal of certain C APIs
* fix linux build
* Avoid duplication
* Fix mem leak
* Minor perf improvements. (#1580)
* Minor perf improvements.
  - Cache the vector sizes in IExecutionFrame and NodeIndexInfo to avoid calls to size() - 2 instructions instead of 10
  - Remove an unnecessary check in IExecutionFrame - add a check to the ctor so we guarantee it's unnecessary
  - Reserve memory for the vectors in BroadcastIterator - saves reallocs if more than one value is added - but it is rare with the mlperf models for multiple values to be added, so the benefit is limited.
  - slight tweak to the Broadcaster ctor code to make it more readable
* Serialize optimized onnx model (#1470) (see the Python sketch below)
* Model serialization
* Removed duplicate symbol
* Minor update
* Review comments
* add tests
* Model serialization
* Removed duplicate symbol
* Minor update
* Merged PR 1106437: Model Serialization in onnxruntime
* Review comments
* Merged PR 1107226: Review comments
* add tests
* Fixed merge conflict
* Correct python tests
* InferenceSession Refeed Test
* Replace use of widechar const literal-L
* Fixed failing tests
* Updated comment
* Removed unnecessary session options
* Spell check on comments
* Do not serialize when level 3 optimization specified
* Updated error logs
* Changed log severity to WARN
* Fix log message truncation on Windows when printf formatting is used. (#1599)
* Fix log message truncation and add a unit test. On Windows vsnprintf_s returns -1 when truncating, so we need to differentiate that from a real error.
* Remove copy of generator in Multinomial (#1611)
* Remove copy of generator in Multinomial so that different values are generated each time. Add ability to test
* Kezhan/execute graph refactoring (#1553)
* checking execution provider logic updated.
* fix the logic of copy input and output.
* update
* update
* update
* update
* update
* update
* fix ngraph failure.
* fix comments
* Cleanup csharp API SessionOptions and RunOptions to be consistent with other APIs (#1570)
  - Updated SessionOptions API to use properties instead of setter/getter methods.
  - Added missing APIs.
  - Added RunOptions.
* Make changes to pipeline template to include missing headers in tars/zips (#1617)
* Fix trtlogger segfault. re-enable SoftPlus unit test for TRT. add doc… (#1623)
* Fix trtlogger segfault. re-enable SoftPlus unit test for TRT. add documentation for ORT_TENSORRT* env vars.
* Update TensorRT-ExecutionProvider.md
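The `Serialize optimized onnx model (#1470)` entry above lets a session write the graph-transformed model back to disk. A hedged sketch of how this is typically driven from the Python API follows; the `optimized_model_filepath` property name is taken from the current onnxruntime Python binding and is an assumption for the exact state of this patch, and the file paths are placeholders:

```python
import onnxruntime as ort

so = ort.SessionOptions()
# Ask the session to save the optimized graph once the transformers have run.
# Per the log above, serialization is skipped when the highest optimization level is requested.
so.optimized_model_filepath = "model.optimized.onnx"

sess = ort.InferenceSession("model.onnx", so)  # loads, optimizes, and serializes
```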
* Use a friendly enum for graph optimization level. (#1586)
* Mention OrtCreateSessionFromArray in C API doc
* review changes
* use enum for graph optimization level
* Use explicit values for enums
* updates...
* Add friendly enum for graph optimization levels in C, C# and Python APIs.
* Fix linux build
* Fix build breakage due to master merge
* PR comments
* Generate documentation from the registered operator kernels (#1395)
  - Added python script for generating markdown doc from the registered op kernels.
  - Made some conditional changes in the pybind to expose the necessary python API
  - Added some missing type-constraints in the op kernel registrations
* Fix incorrect box offset computation in NMS op (#1624)
* More changes
* Fix NMS
* nits
* Integrate featurizers (#1573)
  Added Sample Featurizer and Infrastructure. Make featurizers and unit tests compile and run with GTest. Create definitions for the first featurizer kernel. Add new operator domain. Create datetime_transformer kernel and build. Move OPAQUE type definitions for featurizer kernels out to a separate cc. Register them with the type system. Provide unit tests for the new AutoML DateTimeTransformer kernel. Make necessary adjustments to the test infrastructure to make it run with new types.
* Support int64 for ReduceMax (#1625)
* update onnx to latest commit (#1622)
* update onnx to latest commit
* Disable and/or fix failing tests
* disable not yet implemented tests for opset 11
* disable tests
* fix bug in mkldnn fp16 graph check
* Copy System.Numerics.Tensors sources from dotnet/corefx into onnxruntime (#1605)
* removed --gen_doc (#1633)
* Fix parsing initial hidden state in RNN (#1626)
* Fix the way initial hidden state is used for reverse direction in RNN
* Add test case
* Updates
* Let mlas use session thread pool (#1609)
  1. Let mlas use session thread pool
  2. Remove onnxruntime_USE_MLAS cmake option
  3. Remove the win32 thread pool code inside mlas
  mlas will:
  1. use the ort thread pool if one is passed in
  2. use openmp if the threadpool parameter is nullptr
  3. run single threaded if the threadpool parameter is nullptr and openmp is disabled
* update TRT EP CIs to use latest model.zip (#1637)
* Add AutoML to 3 main builds. (#1631)
  Fix unit tests. Enable copy elision, do not move movable objects on return by value.
* MLAS: add U8U8 MatMul operation (#1644)
  Implement the first round of changes for quantization inside MLAS. This adds a MatMul operation for U8xU8=S32 for x86/x64 processors. (See the NumPy sketch below.)
* Add uint8 Support for NonZero Op (#1614)
* update MKLML to version which contains fix for thread hang. (#1636)
* update MKLML which has bugfix for thread hang. move PATCH_COMMAND outside BUILD_FOR_NATIVE_MACHINE check.
* MKLML_VERSION 2020.0.20190813 is for windows only.
* MlasGetMaximumThreadCount: plus 1 to the NumThreads from ORT thread pool (#1646)
* Update perf tool documentation to reflect the new graph optimization enums. Relax constraint for enable_all. (#1650)
* Allow user to disable multithreading (#1647)
* Update onnx test runner documentation (#1651)
* Mention OrtCreateSessionFromArray in C API doc
* Update perf tool documentation to reflect the new graph optimization enums. Relax constraint for enable_all.
* Update one more doc
* Update onnx test runner documentation
* Add default in the docs
* Fix memory leak in mlas unit test (#1654)
* fix bug on windows where ops were always getting dumped. (#1648)
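The `MLAS: add U8U8 MatMul operation (#1644)` entry above introduces a uint8 x uint8 GEMM that accumulates into int32. A minimal NumPy reference of the U8xU8=S32 accumulation it performs (an illustration of the math with made-up sample data, not the MLAS kernel itself):

```python
import numpy as np

rng = np.random.RandomState(0)
a = rng.randint(0, 256, size=(4, 8), dtype=np.uint8)  # M x K, unsigned 8-bit
b = rng.randint(0, 256, size=(8, 3), dtype=np.uint8)  # K x N, unsigned 8-bit

# Widen to int32 before multiplying so products and sums cannot overflow.
c = a.astype(np.int32) @ b.astype(np.int32)            # M x N, signed 32-bit
print(c.dtype, c.shape)
```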
* Remove --whole-archive (#1655)
* Check return value from CreateFeedsFetchesManager. (#1653)
  Also clean up a couple of unused variables.
* Update PyTorch section for supported onnx version (#1635)
  The PyTorch exporter in PyTorch 1.2 can natively support multiple opsets now.
* cudnnRNNForwardInferenceEx doesn't support 0-length sequences in the batches
  Fix the issue that cudnnRNNForwardInferenceEx doesn't support 0-length sequences in the batches.
  Solution: Reset the 0-length sequences to 1 for the batches before calling cudnnRNNForwardInferenceEx, and keep an array tracking the batch ids that have 0-length sequences. Once the result is obtained, call a CUDA kernel to mask the output using the batch ids tracked in the array.
* Add details of which node was not able to be placed on an execution provider. (#1665)
* nGraph EP Optimizations (#1630)
* Added check for unnecessary function initializations, and removed lock from unneeded areas of code.
* Added LRU cache to EP.
* Bugfixes for nGraph EP Optimization PR
* Changed default cache size to 500 and refactored mutex readability.
* Fixed unsafe environmental variable fetch for Windows.
* Cleaned up Windows environment functions and cleaned up mutexes.
* Fix a few errors in the NuGet pipeline (still broken) (#1656)
* update set fetches for execution with allocation plan. (#1668)
* Support Tensor<bool> and Tensor<int8> in C# API. Support Tensor<string> as input. Fix a bug in the InferenceSession Run() with RunOptions (#1671)
  - Support bool-Tensor and int8-Tensor in input-output of C# api
  - Support string-tensor as input in C# api
  - Fix a bug in InferenceSession.Run() -- RunOptions was not passed into the native call
* Optimize kernel index (#1672)
* update clip for opset 11 (#1661)
* update clip for opset 11
* exclude ngraph provider for clip unit tests
* exclude ngraph for all clip opset 11 tests
* fix op version
* Add support of ReduceSum int64 (#1664)
* Add support of ReduceSum int64
* add unit test for int64
* int64 support for 'where' op (#1666)
* Added some mo optimizations to improve performance (#1674)
  Signed-off-by: suryasidd
* Don't create the default allocator every single time. Rename API accordingly. Expose Session/Run log severity levels. (#1615)
* Mention OrtCreateSessionFromArray in C API doc
* Don't create the default allocator every single time. Rename API accordingly.
* Don't create the default allocator every single time. Rename API accordingly.
* updates...
* updates...
* PR comments
* fix typo in license header
* fix build
* Share default CPU allocator with Mlas preferred alignment (#1682)
  Description: make the default CPU allocator use the MLAS preferred alignment.
  Motivation and Context: this is needed for the C API to have an aligned default CPU allocator, the same as the one in the CPU provider.
* More fixes on the NuGet CPU CI pipeline (#1688)
  - Fix the Windows end-to-end test in NuGet CI
  - Skip the TestModelSerialization, because it is failing on Linux. Must be fixed before the API is released for use. Owner is notified.
* treat zero point properly (#1686)
* use MLAS for QGEMM in matmulInteger and convInteger (#1692)
* use mlas qgemm for u8u8_s32 gemms
* update test
* fix typo in max batch size error msg. (#1687)
* Python API naming and other cleanup (#1678)
  - Make the naming of properties in python SessionOptions and RunOptions consistent with other apis.
  - Remove unnecessary apis
* make gemmlowp default for arm (#1701)
* make gemmlowp default for arm
* force use_gemmlowp in header for default case
* remove unnecessary white space
* Doc updates (#1522)
* Updates
* Remove preview texts
* Update README.md
* Updates
* Update README.md
* Update README.md
* Minor wording update
* Update README.md
* Update doc on CUDA version
* revert update
* Update readme for issue #1558
* Clean up example section
* Cosmetic updates - Add an index of build instructions for browsability - Update build CUDA version from 9.1 to 10
* Fix broken link
* Update README to reflect upgrade to pip requirement
* Update CuDNN version for Linux Python packages
* Clean up content: updated ordering and added table of contents
* Minor format fixes
* Move Android NNAPI under EP section
* Add link to operator support documentation
* Fix typo
* typo fix
* remove todo section
* remove @PCGOTREL x64 usage (#1707)
  Avoid the need for @PCGOTREL relocations by annotating MLAS global data shared with assembly modules with attribute(visibility("hidden")).
* MLAS: Android sgemm kernel build fix (#1710)
  Fix the aarch64 kernel to build properly with the Android NDK (specifically clang).
* Remove TaskThreadPool (#1713)
* Allow input used across execution providers as long as they use the same allocator device (#1715)
  Description: Currently ORT throws an error when one input is used in different EPs. This change removes that restriction.
  Motivation and Context: It is now possible to share inputs across EPs now that allocations are device-based, instead of EP-based.
* Add support for int8 x uint8 for MatMulInteger, and int16 x int16 custom op (#1391)
  Description: The change adds the necessary quantization support on CPU for mixed int8/uint8, as well as int16, matrix multiply operations that output int32.
  Motivation and Context: Integer operations are critical for quantized models' performance. The current MatMulInteger implementation on CPU only supports uint8 x uint8, while the spec supports int8 x uint8. Having a default CPU implementation that fully supports the spec would help accuracy verification. Besides, some models may need to quantize to int16, but the MatMulInteger op does not support that yet. A custom op, MatMulInteger16, is added to satisfy such models. (See the NumPy sketch below.)
* Use exec form of ENTRYPOINT for docker server (#1690)
* Use exec form of ENTRYPOINT for docker server
  # Issue
  The entrypoint currently uses the shell form - this prevents users from passing in any cmdline arguments... also passing a model_path in means the server only works if the envvar is set... however this is not what the error message says!
  ```
  $ docker run -v /home/rakelkar/try/onnxzoo/style:/mnt/models -it mcr.microsoft.com/onnxruntime/server --model_path /mnt/models/model.onnx
  Version: local_build
  Commit ID: default
  model_path must be the location of a valid file
  Allowed options:
    -h [ --help ]                Shows a help message and exits
    --log_level arg (=info)      Logging level. Allowed options (case sensitive): verbose, info, warning, error, fatal
    --model_path arg             Path to ONNX model
    --address arg (=0.0.0.0)     The base HTTP address
    --http_port arg (=8001)      HTTP port to listen to requests
    --num_http_threads arg (=4)  Number of http threads
    --grpc_port arg (=50051)     GRPC port to listen to requests
  ```
  # Fix
  1. remove the env var
  2. use the exec form
* Update readme to use model_path arg
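Several entries above (`treat zero point properly (#1686)`, `use MLAS for QGEMM in matmulInteger and convInteger (#1692)`, and the int8 x uint8 MatMulInteger support in #1391) revolve around zero-point handling in integer matrix multiplication. A minimal NumPy reference of the MatMulInteger semantics, with zero points subtracted before the int32 accumulation (a sketch of the ONNX operator definition with made-up sample data, not the ORT kernels):

```python
import numpy as np

def matmul_integer(a, b, a_zero_point=0, b_zero_point=0):
    # MatMulInteger: C = (A - a_zero_point) * (B - b_zero_point), accumulated in int32.
    # A and B may each be int8 or uint8; widening to int32 first avoids overflow.
    a32 = a.astype(np.int32) - np.int32(a_zero_point)
    b32 = b.astype(np.int32) - np.int32(b_zero_point)
    return a32 @ b32

a = np.array([[-1, 2], [3, -4]], dtype=np.int8)   # signed 8-bit input
b = np.array([[5, 6], [7, 8]], dtype=np.uint8)    # unsigned 8-bit input
print(matmul_integer(a, b, a_zero_point=1, b_zero_point=128))
```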
* Support 'Bilinear' mode for 2D inputs in Resize and Upsample kernels (#1679)
* Support bilinear mode with actual 2D inputs in Resize and upsample
* Fix build break
* Fix build break
* Add test
* CUDA changes
* Resolve PR comments
* Resolve comments
* add implementation for dynamic quantize linear (#1697) (see the NumPy sketch below)
* Fix reading of onnx domain causing one of the automl models to break in 0.5 release. (#1694)
* Mention OrtCreateSessionFromArray in C API doc
* Fix registration of Equal op causing one of the automl models to break in 0.5 release.
* updates...
* Fix an issue where the CUDA EP falls back too many nodes to CPU in some cases, which causes huge data copies. If a node's inputs are all initializers, we shouldn't fall back that node to CPU. (#1727)
  https://github.com/microsoft/onnxruntime/issues/1675
  Currently, if a node's inputs are all initializers, the CUDA EP will fall back that node to CPU, and it will also fall back some nodes under it. This can cause huge data copies. In the case reported by a user, the model has several Slice ops whose inputs come from initializers, and a Concat op that concatenates the Slice outputs. The data is huge (16MB) after the Concat, which makes the copy from CPU to GPU quite costly because it is a sync copy.
  Fix: if a node's inputs are all initializers, we shouldn't fall back the node to CPU.
* Publish perf tool with nightly build (#1728)
* Update the docker file for OpenVINO (#1741)
  Update the docker file for OpenVINO, which is used for AML.
* Fix typo in NMS code
* MKL-DNN EP: control flow fix (#1740)
* moved subgraph_index to MklDnn Execution Provider
* code cleanup
* Implementation of Nuphar execution provider (#881)
* Implement Nuphar execution provider
  The Nuphar execution provider is a TVM-based compilation provider. It has shown great speedups for RNN models using Scan. This PR is mainly a preview of the shared codegen library for other TVM-based providers.
* Fix submodules
* Fix TVM submodule
* Update Nuphar to latest and resolve conflicts
* Remove stale files caused by merge -X theirs
* Revert heap buffer change to not introduce onnxruntime_framework into onnxruntime_perf_test
* Fix bad merge
* Merge from Nuphar
* Fix warning treated as error, revert some unnecessary changes
* Revert some more test changes
* Some more test reverts or comments to make review easier. New tests could be added later.
* One more revert of unnecessary changes
* More change reverts. Tests could be added back later.
* Enforce shape validation. (#1716)
* Mention OrtCreateSessionFromArray in C API doc
* Enforce shape validation.
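The `add implementation for dynamic quantize linear (#1697)` entry above adds a kernel that derives the scale and zero point from the input data before quantizing to uint8. A minimal NumPy sketch of that computation as the ONNX DynamicQuantizeLinear operator defines it (an illustration only, with made-up sample data, not the ORT kernel):

```python
import numpy as np

def dynamic_quantize_linear(x):
    # Derive scale/zero_point from the data range (adjusted to include 0), then quantize to uint8.
    qmin, qmax = 0.0, 255.0
    x_min = min(float(x.min()), 0.0)  # the range must include zero
    x_max = max(float(x.max()), 0.0)
    scale = (x_max - x_min) / (qmax - qmin)
    if scale == 0.0:                  # constant all-zero input: avoid divide-by-zero
        scale = 1.0
    zero_point = np.round(np.clip(qmin - x_min / scale, qmin, qmax)).astype(np.uint8)
    y = np.clip(np.round(x / scale) + zero_point, qmin, qmax).astype(np.uint8)
    return y, np.float32(scale), zero_point

x = np.array([-1.0, -0.25, 0.0, 0.5, 2.0], dtype=np.float32)
y, scale, zp = dynamic_quantize_linear(x)
print(y, scale, zp)
```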
* Update broken models * enable quantizing specific nodes (#1742) * update quantization script --- .gitmodules | 8 +- BUILD.md | 328 +- README.md | 93 +- cgmanifest.json | 2 +- cmake/CMakeLists.txt | 26 +- cmake/external/mkldnn.cmake | 10 +- cmake/external/ngraph.cmake | 8 +- cmake/external/onnx | 2 +- cmake/external/onnx-tensorrt | 2 +- cmake/external/tvm | 2 +- cmake/onnxruntime.cmake | 16 +- cmake/onnxruntime_automl_featurizers.cmake | 44 + cmake/onnxruntime_common.cmake | 9 +- cmake/onnxruntime_graph.cmake | 8 + cmake/onnxruntime_mlas.cmake | 59 +- cmake/onnxruntime_nuphar_extern.cmake | 39 + cmake/onnxruntime_providers.cmake | 87 +- cmake/onnxruntime_python.cmake | 13 + cmake/onnxruntime_unittests.cmake | 40 +- cmake/onnxruntime_util.cmake | 11 +- .../ngraph/ngraph_fix_install_error.patch | 127 - .../ngraph/ngraph_fix_library_path.patch | 33 - .../ngraph_fix_mkldnn_missing_symbol.patch | 64 + csharp/OnnxRuntime.CSharp.proj | 10 +- .../OnnxRuntime.snk | Bin .../Program.cs | 4 +- .../DisposableNamedOnnxValue.cs | 11 +- .../InferenceSession.cs | 160 +- .../Microsoft.ML.OnnxRuntime.csproj | 72 +- .../NamedOnnxValue.cs | 137 +- .../NativeMemoryAllocator.cs | 19 +- .../Microsoft.ML.OnnxRuntime/NativeMethods.cs | 65 +- .../Microsoft.ML.OnnxRuntime/OnnxRuntime.cs | 32 +- .../Microsoft.ML.OnnxRuntime/RunOptions.cs | 120 + .../SessionOptions.cs | 319 +- .../Tensors/ArrayTensorExtensions.cs | 66 + .../Tensors/ArrayUtilities.cs | 227 + .../Tensors/DenseTensor.cs | 188 + .../Tensors/Tensor.cs | 1311 ++ .../CXX_Api_Sample.cpp | 11 +- .../C_Api_Sample.cpp | 9 +- .../InferenceTest.cs | 183 +- .../Microsoft.ML.OnnxRuntime.Tests.csproj | 32 + .../Tensors/NativeMemory.cs | 119 + .../Tensors/TensorArithmetic.cs | 16201 ++++++++++++++++ .../Tensors/TensorArithmetic.tt | 249 + .../Tensors/TensorExtensions.cs | 42 + .../Tensors/TensorOperations.cs | 750 + .../Tensors/TensorOperations.tt | 251 + .../Tensors/TensorTemplate.ttinclude | 328 + .../Tensors/TensorTests.cs | 2243 +++ .../Tensors/TensorTestsBase.cs | 164 + csharp/testdata/test_types_BOOL.pb | Bin 167 -> 151 bytes csharp/testdata/test_types_INT8.pb | Bin 167 -> 151 bytes csharp/testdata/test_types_STRING.pb | Bin 167 -> 151 bytes .../Program.cs | 23 +- dockerfiles/Dockerfile.cuda | 8 +- dockerfiles/Dockerfile.openvino | 24 +- dockerfiles/Dockerfile.server | 3 +- dockerfiles/Dockerfile.source | 7 +- dockerfiles/Dockerfile.tensorrt | 8 +- dockerfiles/README.md | 21 +- .../{ => scripts}/install_common_deps.sh | 8 +- docs/ONNX_Runtime_Perf_Tuning.md | 2 +- docs/OperatorKernels.md | 470 + docs/Versioning.md | 2 +- .../Nuphar-ExecutionProvider.md | 142 + .../TensorRT-ExecutionProvider.md | 26 +- include/onnxruntime/core/common/callback.h | 17 - .../onnxruntime/core/framework/allocator.h | 6 + .../onnxruntime/core/framework/data_types.h | 3 + .../core/framework/kernel_def_builder.h | 6 + .../core/framework/kernel_registry.h | 8 + .../onnxruntime/core/framework/op_kernel.h | 7 +- include/onnxruntime/core/framework/tensor.h | 4 +- .../onnxruntime/core/framework/tensor_shape.h | 9 +- include/onnxruntime/core/graph/constants.h | 1 + include/onnxruntime/core/graph/graph.h | 1 - include/onnxruntime/core/graph/graph_viewer.h | 11 +- .../onnxruntime/core/platform/threadpool.h | 23 +- .../nuphar/nuphar_provider_factory.h | 17 + .../tensorrt/tensorrt_provider_factory.h | 2 +- .../core/session/onnxruntime_c_api.h | 79 +- .../core/session/onnxruntime_cxx_api.h | 21 +- .../core/session/onnxruntime_cxx_inline.h | 24 +- onnxruntime/__init__.py | 2 +- 
onnxruntime/automl_ops/automl_featurizers.h | 8 + onnxruntime/automl_ops/automl_types.cc | 39 + onnxruntime/automl_ops/automl_types.h | 13 + .../automl_ops/cpu/datetime_transformer.cc | 42 + onnxruntime/automl_ops/cpu_automl_kernels.cc | 25 + onnxruntime/automl_ops/cpu_automl_kernels.h | 13 + .../cpu/attnlstm/attention_wrapper.cc | 25 +- .../cpu/attnlstm/attention_wrapper.h | 4 +- .../cpu/attnlstm/bahdanau_attention.cc | 34 +- .../cpu/attnlstm/bahdanau_attention.h | 3 +- .../cpu/attnlstm/deep_cpu_attn_lstm.cc | 23 +- .../cpu/attnlstm/uni_dir_attn_lstm.cc | 8 +- .../cpu/attnlstm/uni_dir_attn_lstm.h | 4 +- .../contrib_ops/cpu/matmul_integer16.cc | 45 + .../contrib_ops/cpu/matmul_integer16.h | 22 + onnxruntime/contrib_ops/cpu/nchwc_ops.cc | 3 - onnxruntime/contrib_ops/cpu/nchwc_ops.h | 2 + .../contrib_ops/cpu/word_conv_embedding.cc | 12 +- .../contrib_ops/cpu/word_conv_embedding.h | 5 +- .../contrib_ops/cpu_contrib_kernels.cc | 2 + .../src/FeaturizerPrep/Featurizer.h | 163 + .../Featurizers/DateTimeFeaturizer.cpp | 56 + .../Featurizers/DateTimeFeaturizer.h | 101 + .../FeaturizerPrep/Featurizers/SampleAdd.cpp | 40 + .../FeaturizerPrep/Featurizers/SampleAdd.h | 95 + .../Featurizers/UnitTests/CMakeLists.txt | 48 + .../DateTimeFeaturizer_UnitTests.cpp | 125 + .../UnitTests/SampleAdd_UnitTest.cpp | 22 + .../Featurizers/UnitTests/code_coverage.yaml | 5 + .../featurizers/src/FeaturizerPrep/Traits.h | 218 + .../FeaturizerPrep/UnitTests/CMakeLists.txt | 41 + .../UnitTests/Featurizer_UnitTest.cpp | 104 + .../UnitTests/Traits_UnitTests.cpp | 40 + .../FeaturizerPrep/UnitTests/test_main.cpp | 18 + onnxruntime/core/codegen/common/common.cc | 11 +- onnxruntime/core/codegen/common/creator.h | 2 +- onnxruntime/core/codegen/common/dispatcher.h | 2 + onnxruntime/core/codegen/common/profile.h | 2 +- onnxruntime/core/codegen/common/registry.h | 2 + onnxruntime/core/codegen/common/settings.cc | 4 + onnxruntime/core/codegen/common/settings.h | 1 + onnxruntime/core/codegen/mti/math/gemm.cc | 6 +- .../core/codegen/mti/math/matmul_ops.cc | 37 +- .../core/codegen/mti/math/matmul_ops.h | 5 + .../core/codegen/mti/math/unary_ops.cc | 29 +- onnxruntime/core/codegen/mti/mti_tvm_utils.cc | 34 + onnxruntime/core/codegen/mti/mti_tvm_utils.h | 3 + .../math/quantize/matmul_integer.cc | 18 +- .../passes/op_ir_creator/tensor/crop.cc | 3 +- .../passes/op_ir_creator/tensor/transpose.cc | 10 +- .../passes/op_ir_creator/tvm_op_creator.h | 2 +- .../codegen/passes/scheduler/tvm_scheduler.h | 2 +- .../codegen/passes/utils/ort_tvm_utils.cc | 14 +- .../passes/weight_layout/weight_layout.h | 2 +- onnxruntime/core/common/logging/capture.cc | 14 +- onnxruntime/core/common/profiler.cc | 18 + onnxruntime/core/common/profiler.h | 19 + onnxruntime/core/common/task_thread_pool.h | 213 - onnxruntime/core/common/threadpool.cc | 198 +- .../core/framework/allocation_planner.cc | 99 +- onnxruntime/core/framework/allocator.cc | 55 +- onnxruntime/core/framework/bfc_arena.h | 2 +- onnxruntime/core/framework/callback.cc | 10 +- onnxruntime/core/framework/callback.h | 15 + onnxruntime/core/framework/data_types.cc | 41 + onnxruntime/core/framework/error_code.cc | 5 +- onnxruntime/core/framework/execution_frame.cc | 97 +- onnxruntime/core/framework/execution_frame.h | 7 +- .../core/framework/feeds_fetches_manager.h | 3 +- .../core/framework/graph_partitioner.cc | 5 - .../core/framework/kernel_registry_manager.cc | 18 +- onnxruntime/core/framework/mem_pattern.h | 4 +- onnxruntime/core/framework/node_index_info.cc | 4 + 
onnxruntime/core/framework/node_index_info.h | 8 +- .../framework/op_kernel_context_internal.h | 4 +- .../core/framework/parallel_executor.cc | 76 +- .../core/framework/parallel_executor.h | 1 - onnxruntime/core/framework/run_options.cc | 10 + .../framework/sequential_execution_plan.h | 9 + .../core/framework/sequential_executor.cc | 72 +- onnxruntime/core/framework/session_state.cc | 149 +- onnxruntime/core/framework/session_state.h | 60 +- .../framework/session_state_initializer.cc | 139 +- .../framework/session_state_initializer.h | 9 +- onnxruntime/core/framework/tensor.cc | 4 +- onnxruntime/core/framework/tensor_shape.cc | 12 +- .../core/framework/tensorprotoutils.cc | 4 +- onnxruntime/core/framework/utils.cc | 162 +- onnxruntime/core/framework/utils.h | 2 + .../core/graph/automl_ops/automl_defs.cc | 46 + .../core/graph/automl_ops/automl_defs.h | 30 + .../core/graph/contrib_ops/contrib_defs.cc | 39 +- onnxruntime/core/graph/graph_viewer.cc | 23 +- onnxruntime/core/graph/model.cc | 13 +- onnxruntime/core/mlas/inc/mlas.h | 21 + .../aarch64/{sgemma.s => SgemmKernelNeon.S} | 26 +- .../mlas/lib/amd64/AssembleAvx512Vnni.inc | 232 + .../mlas/lib/amd64/QgemmU8U8KernelAvx2.asm | 1241 ++ .../lib/amd64/QgemmU8U8KernelAvx512BW.asm | 114 + .../lib/amd64/QgemmU8U8KernelAvx512Common.inc | 385 + .../lib/amd64/QgemmU8U8KernelAvx512Vnni.asm | 91 + .../arm64/{sgemma.asm => SgemmKernelNeon.asm} | 41 +- onnxruntime/core/mlas/lib/erf.cpp | 2 +- onnxruntime/core/mlas/lib/logistic.cpp | 2 +- onnxruntime/core/mlas/lib/mlasi.h | 91 +- onnxruntime/core/mlas/lib/platform.cpp | 45 +- onnxruntime/core/mlas/lib/qgemm.cpp | 599 + onnxruntime/core/mlas/lib/sgemm.cpp | 2 +- onnxruntime/core/mlas/lib/tanh.cpp | 2 +- onnxruntime/core/mlas/lib/threading.cpp | 92 +- .../core/mlas/lib/x86_64/AssembleAvx512Vnni.h | 238 + .../core/mlas/lib/x86_64/ErfKernelFma3.S | 8 +- .../core/mlas/lib/x86_64/LogisticKernelFma3.S | 5 +- .../mlas/lib/x86_64/QgemmU8U8KernelAvx2.S | 1121 ++ .../mlas/lib/x86_64/QgemmU8U8KernelAvx512BW.S | 120 + .../lib/x86_64/QgemmU8U8KernelAvx512Common.h | 361 + .../lib/x86_64/QgemmU8U8KernelAvx512Vnni.S | 95 + .../core/mlas/lib/x86_64/SconvKernelAvx.S | 6 + .../core/mlas/lib/x86_64/SconvKernelAvx512F.S | 3 + .../core/mlas/lib/x86_64/SconvKernelSse2.S | 3 + .../core/mlas/lib/x86_64/SgemmKernelAvx.S | 15 +- .../core/mlas/lib/x86_64/SgemmKernelFma3.S | 9 +- .../core/mlas/lib/x86_64/SgemmKernelM1Avx.S | 5 +- .../lib/x86_64/SgemmKernelM1TransposeBAvx.S | 5 +- .../core/mlas/lib/x86_64/TanhKernelFma3.S | 5 +- .../optimizer/optimizer_execution_frame.cc | 4 +- .../optimizer/optimizer_execution_frame.h | 2 +- .../core/optimizer/transformer_memcpy.cc | 11 +- onnxruntime/core/platform/env.h | 2 +- onnxruntime/core/platform/posix/env.cc | 8 +- onnxruntime/core/platform/windows/env.cc | 2 +- onnxruntime/core/providers/common.h | 22 + .../core/providers/cpu/controlflow/loop.cc | 1 - .../providers/cpu/controlflow/scan_utils.cc | 1 - .../core/providers/cpu/controlflow/utils.h | 3 +- .../providers/cpu/cpu_execution_provider.cc | 66 +- .../core/providers/cpu/generator/random.cc | 119 +- .../core/providers/cpu/generator/random.h | 55 +- onnxruntime/core/providers/cpu/math/clip.cc | 9 +- onnxruntime/core/providers/cpu/math/clip.h | 36 +- .../providers/cpu/math/element_wise_ops.cc | 54 +- .../providers/cpu/math/element_wise_ops.h | 33 +- onnxruntime/core/providers/cpu/math/gemm.h | 8 +- .../core/providers/cpu/math/logsoftmax.cc | 7 +- onnxruntime/core/providers/cpu/math/matmul.cc | 7 +- .../core/providers/cpu/math/matmul_helper.h 
| 5 +- .../core/providers/cpu/math/matmul_integer.cc | 122 +- .../core/providers/cpu/math/matmul_integer.h | 6 +- .../cpu/math/quantize_linear_matmul.cc | 104 +- .../cpu/math/quantize_linear_matmul.h | 3 +- .../core/providers/cpu/math/softmax.cc | 6 +- .../core/providers/cpu/math/softmax_shared.cc | 5 +- .../core/providers/cpu/math/softmax_shared.h | 5 +- .../core/providers/cpu/ml/label_encoder.cc | 109 +- .../core/providers/cpu/ml/label_encoder.h | 62 + onnxruntime/core/providers/cpu/nn/Unpool.cc | 6 +- onnxruntime/core/providers/cpu/nn/conv.cc | 23 +- .../core/providers/cpu/nn/conv_integer.cc | 67 +- .../core/providers/cpu/nn/conv_transpose.cc | 9 +- onnxruntime/core/providers/cpu/nn/pool.cc | 2 +- onnxruntime/core/providers/cpu/nn/pool_base.h | 13 +- .../core/providers/cpu/nn/qlinearconv.cc | 110 +- .../core/providers/cpu/nn/qlinearconv.h | 42 +- .../object_detection/non_max_suppression.cc | 10 +- .../cpu/object_detection/roialign.cc | 2 +- .../providers/cpu/reduction/reduction_ops.cc | 10 + .../core/providers/cpu/rnn/deep_cpu_gru.cc | 29 +- .../core/providers/cpu/rnn/deep_cpu_lstm.cc | 32 +- .../core/providers/cpu/rnn/deep_cpu_lstm.h | 4 +- onnxruntime/core/providers/cpu/rnn/rnn.cc | 19 +- .../core/providers/cpu/rnn/rnn_helpers.h | 6 +- onnxruntime/core/providers/cpu/symbols.txt | 9 +- .../core/providers/cpu/tensor/cast_op.cc | 4 +- .../core/providers/cpu/tensor/compress.cc | 3 +- .../core/providers/cpu/tensor/concat.cc | 43 +- .../cpu/tensor/dynamicquantizelinear.cc | 75 + .../cpu/tensor/dynamicquantizelinear.h | 20 + .../core/providers/cpu/tensor/identity_op.cc | 3 +- .../core/providers/cpu/tensor/nonzero_op.cc | 44 +- .../core/providers/cpu/tensor/onehot.cc | 5 +- .../providers/cpu/tensor/quantize_linear.cc | 37 +- onnxruntime/core/providers/cpu/tensor/size.cc | 3 +- onnxruntime/core/providers/cpu/tensor/tile.cc | 3 +- .../core/providers/cpu/tensor/upsample.cc | 69 +- .../core/providers/cpu/tensor/upsample.h | 7 +- .../core/providers/cpu/tensor/where_op.cc | 2 +- .../core/providers/cuda/cuda_allocator.cc | 3 +- .../core/providers/cuda/cuda_allocator.h | 6 +- .../providers/cuda/cuda_execution_provider.cc | 21 +- .../core/providers/cuda/cudnn_common.h | 2 +- .../cuda/math/binary_elementwise_ops.cc | 40 +- .../core/providers/cuda/rnn/cudnn_rnn_base.cc | 159 +- .../core/providers/cuda/rnn/cudnn_rnn_base.h | 82 +- onnxruntime/core/providers/cuda/rnn/gru.h | 3 +- onnxruntime/core/providers/cuda/rnn/lstm.h | 3 +- onnxruntime/core/providers/cuda/rnn/rnn.h | 6 +- .../core/providers/cuda/rnn/rnn_impl.cu | 50 +- .../core/providers/cuda/rnn/rnn_impl.h | 7 + .../core/providers/cuda/tensor/compress.cc | 3 +- .../core/providers/cuda/tensor/resize_impl.cu | 71 +- .../core/providers/cuda/tensor/tile.cc | 3 +- .../core/providers/cuda/tensor/upsample.cc | 28 +- .../providers/cuda/tensor/upsample_impl.cu | 68 +- .../mkldnn/mkldnn_execution_provider.cc | 8 +- .../mkldnn/mkldnn_execution_provider.h | 2 + .../mkldnn/mkldnn_provider_factory.cc | 1 + .../core/providers/mkldnn/subgraph/subgraph.h | 6 +- .../core/providers/ngraph/ngraph_custom_op.cc | 95 +- .../core/providers/ngraph/ngraph_custom_op.h | 7 +- .../ngraph/ngraph_execution_provider.cc | 189 +- .../ngraph/ngraph_execution_provider.h | 9 +- .../nuphar/common/analysis/analysis.h | 45 + .../nuphar/common/analysis/graph_stats.h | 77 + .../common/analysis/output_alias_analysis.cc | 109 + .../common/analysis/output_alias_analysis.h | 43 + .../nuphar/common/analysis/shape_expr.h | 243 + .../common/analysis/subgraph_codegen_stats.cc | 63 + 
.../common/analysis/subgraph_codegen_stats.h | 34 + .../analysis/subgraph_partition_stats.cc | 28 + .../analysis/subgraph_partition_stats.h | 31 + .../common/analysis/use_count_analysis.cc | 264 + .../common/analysis/use_count_analysis.h | 83 + .../nuphar/common/nuphar_settings.cc | 132 + .../providers/nuphar/common/nuphar_settings.h | 48 + .../providers/nuphar/common/nuphar_subgraph.h | 106 + .../nuphar/common/nuphar_tvm_utils.cc | 174 + .../nuphar/common/nuphar_tvm_utils.h | 26 + .../core/providers/nuphar/common/utils.cc | 76 + .../core/providers/nuphar/common/utils.h | 23 + .../nuphar/compiler/codegen_manager.cc | 233 + .../nuphar/compiler/codegen_manager.h | 43 + .../providers/nuphar/compiler/func_info.cc | 562 + .../providers/nuphar/compiler/func_info.h | 122 + .../nuphar/compiler/initializer_info.h | 34 + .../nuphar/compiler/nuphar_codegen_ctx.cc | 247 + .../nuphar/compiler/nuphar_codegen_ctx.h | 147 + .../nuphar/compiler/nuphar_compiler.cc | 229 + .../nuphar/compiler/nuphar_compiler.h | 65 + .../providers/nuphar/compiler/nuphar_handle.h | 40 + .../nuphar/compiler/nuphar_op_ir_builder.cc | 311 + .../nuphar/compiler/nuphar_op_ir_builder.h | 34 + .../compiler/nuphar_schedule_builder.cc | 77 + .../nuphar/compiler/nuphar_schedule_builder.h | 20 + .../nuphar/compiler/traverse_shape_infer.cc | 128 + .../nuphar/compiler/traverse_shape_infer.h | 49 + .../compiler/x86/op_ir_creator/all_ops.h | 64 + .../compiler/x86/op_ir_creator/math/gemm.cc | 52 + .../x86/op_ir_creator/math/logsoftmax.cc | 32 + .../compiler/x86/op_ir_creator/math/matmul.cc | 148 + .../math/quantize/matmul_integer.cc | 126 + .../x86/op_ir_creator/math/reduce_ops.cc | 169 + .../x86/op_ir_creator/math/softmax.cc | 32 + .../x86/op_ir_creator/math/unary_ops.cc | 124 + .../x86/op_ir_creator/tensor/slice.cc | 72 + .../compiler/x86/op_ir_creator/tensor/tile.cc | 34 + .../x86/scheduler/analysis_schedule.cc | 31 + .../x86/scheduler/nuphar_scheduler.cc | 53 + .../compiler/x86/scheduler/nuphar_scheduler.h | 41 + .../x86/scheduler/ort_type_schedule.cc | 270 + .../x86/scheduler/partial_schedule.cc | 21 + .../scheduler/tensorize/intrin_gemv_16bit.cc | 100 + .../scheduler/tensorize/intrin_gemv_16bit.h | 20 + .../scheduler/tensorize/intrin_gemv_8bit.cc | 104 + .../scheduler/tensorize/intrin_gemv_8bit.h | 20 + .../tensorize/intrin_gemv_ll_extern.cc | 103 + .../tensorize/intrin_gemv_ll_extern.h | 13 + .../scheduler/tensorize/intrin_gemv_ll_ir.cc | 96 + .../scheduler/tensorize/intrin_gemv_ll_ir.h | 20 + .../x86/scheduler/tensorize/ll/gemv_impl.cpp | 18 + .../x86/scheduler/tensorize/ll/gemv_impl.h | 137 + .../x86/scheduler/tensorize/tensorize_base.h | 76 + .../tensorize/tensorize_utilities.cc | 73 + .../scheduler/tensorize/tensorize_utilities.h | 30 + .../x86/scheduler/tensorize_schedule.cc | 144 + .../x86/scheduler/tvm_rule_schedule.cc | 120 + .../nuphar/compiler/x86/x86_target_info.cc | 19 + .../nuphar/compiler/x86/x86_target_info.h | 33 + .../providers/nuphar/extern/igemv_avx2.cc | 740 + .../core/providers/nuphar/extern/igemv_avx2.h | 46 + .../core/providers/nuphar/extern/igemv_mkl.cc | 36 + .../core/providers/nuphar/extern/igemv_mkl.h | 30 + onnxruntime/core/providers/nuphar/kernel.cc | 230 + onnxruntime/core/providers/nuphar/kernel.h | 152 + .../nuphar/mti_x86/math/halide_ops.cc | 307 + .../nuphar/mti_x86/math/halide_ops.h | 49 + .../nuphar/mti_x86/math/logsoftmax.cc | 16 + .../nuphar/mti_x86/math/logsoftmax.h | 15 + .../nuphar/mti_x86/math/matmul_ops.cc | 232 + .../nuphar/mti_x86/math/matmul_ops.h | 24 + 
.../nuphar/mti_x86/math/reduce_ops.cc | 356 + .../nuphar/mti_x86/math/reduce_ops.h | 40 + .../providers/nuphar/mti_x86/math/softmax.cc | 16 + .../providers/nuphar/mti_x86/math/softmax.h | 15 + .../nuphar/mti_x86/math/softmax_internal.cc | 56 + .../nuphar/mti_x86/math/softmax_internal.h | 17 + .../nuphar/mti_x86/math/unary_ops.cc | 186 + .../providers/nuphar/mti_x86/math/unary_ops.h | 25 + .../mti_x86/quantize/imatmul16_extern.cc | 149 + .../mti_x86/quantize/imatmul16_extern.h | 29 + .../nuphar/mti_x86/quantize/imatmul_extern.cc | 143 + .../nuphar/mti_x86/quantize/imatmul_extern.h | 29 + .../nuphar/nuphar_execution_provider.cc | 410 + .../nuphar/nuphar_execution_provider.h | 163 + .../nuphar/nuphar_provider_factory.cc | 36 + .../nuphar/partition/graph_partitioner.cc | 190 + .../nuphar/partition/graph_partitioner.h | 48 + .../providers/nuphar/partition/partitioner.cc | 266 + .../providers/nuphar/partition/partitioner.h | 100 + .../nuphar/partition/subgraph_partitioner.cc | 403 + .../nuphar/partition/subgraph_partitioner.h | 56 + .../providers/nuphar/runtime/compute_ctx.cc | 43 + .../providers/nuphar/runtime/compute_ctx.h | 233 + .../runtime/control_flow/loop_exec_ctx.h | 43 + .../runtime/control_flow/scan_exec_ctx.cc | 530 + .../runtime/control_flow/scan_exec_ctx.h | 87 + .../providers/nuphar/runtime/exec_block.cc | 29 + .../providers/nuphar/runtime/exec_block.h | 54 + .../core/providers/nuphar/runtime/handle.h | 24 + .../nuphar/runtime/sequential/basic.cc | 195 + .../nuphar/runtime/sequential/basic.h | 34 + .../nuphar/runtime/sequential/loop.cc | 72 + .../nuphar/runtime/sequential/loop.h | 30 + .../core/providers/nuphar/runtime/utils.h | 72 + .../core/providers/nuphar/scripts/README.md | 25 + .../nuphar/scripts/cntk_converter.py | 81 + .../nuphar/scripts/create_shared.cmd | 65 + .../providers/nuphar/scripts/create_shared.sh | 64 + .../providers/nuphar/scripts/model_editor.py | 629 + .../nuphar/scripts/model_quantizer.py | 310 + .../providers/nuphar/scripts/node_factory.py | 153 + .../providers/nuphar/scripts/rnn_benchmark.py | 205 + .../nuphar/scripts/symbolic_shape_infer.py | 643 + onnxruntime/core/providers/nuphar/symbols.txt | 1 + .../openvino/openvino_mo/openvino_mo.py | 36 +- .../providers/tensorrt/tensorrt_allocator.h | 32 - .../tensorrt/tensorrt_execution_provider.cc | 113 +- .../tensorrt/tensorrt_execution_provider.h | 4 +- .../tensorrt/tensorrt_provider_factory.cc | 17 +- .../core/session/abi_session_options.cc | 38 +- .../session/default_cpu_allocator_c_api.cc | 14 +- onnxruntime/core/session/environment.cc | 7 + onnxruntime/core/session/inference_session.cc | 49 +- onnxruntime/core/session/inference_session.h | 11 +- onnxruntime/core/session/onnxruntime_c_api.cc | 67 +- onnxruntime/core/util/gemmlowp_common.cc | 49 + onnxruntime/core/util/gemmlowp_common.h | 65 + .../core/util/gemmlowp_common_wrapper.h | 2 + onnxruntime/core/util/math.h | 9 +- onnxruntime/core/util/math_cpu.cc | 55 +- onnxruntime/core/util/math_cpuonly.h | 14 +- .../core/util/protobuf_parsing_utils.cc | 2 + onnxruntime/core/util/qmath.cc | 33 + onnxruntime/core/util/qmath.h | 38 + .../python/onnxruntime_pybind_state.cc | 296 +- .../python/tools/quantization/quantize.py | 298 +- onnxruntime/server/environment.cc | 2 +- onnxruntime/server/executor.cc | 1 - .../automl_ops/datetimetransformer_test.cc | 97 + .../test/common/logging/logging_test.cc | 38 +- .../test/contrib_ops/matmul_integer16_test.cc | 41 + .../test/framework/TestAllocatorManager.cc | 4 +- .../test/framework/allocation_planner_test.cc | 43 +- 
.../framework/cuda/allocator_cuda_test.cc | 4 +- .../test/framework/execution_frame_test.cc | 83 +- .../test/framework/inference_session_test.cc | 83 + onnxruntime/test/framework/math_test.cc | 58 +- .../test/framework/session_state_test.cc | 147 +- .../test_tensor_loader.cc | 98 +- onnxruntime/test/framework/test_utils.cc | 3 +- onnxruntime/test/mlas/unittest.cpp | 309 +- onnxruntime/test/onnx/README.txt | 2 +- onnxruntime/test/onnx/TestCase.cc | 79 +- onnxruntime/test/onnx/TestCase.h | 5 +- onnxruntime/test/onnx/callback.cc | 16 + onnxruntime/test/onnx/callback.h | 17 + onnxruntime/test/onnx/heap_buffer.cc | 8 +- onnxruntime/test/onnx/heap_buffer.h | 8 +- onnxruntime/test/onnx/main.cc | 92 +- onnxruntime/test/onnx/mem_buffer.h | 27 + .../test/onnx/microbenchmark/model_init.cc | 222 - .../test/onnx/microbenchmark/modeltest.cc | 3 +- onnxruntime/test/onnx/runner.cc | 82 +- onnxruntime/test/onnx/tensorprotoutils.cc | 459 + onnxruntime/test/onnx/tensorprotoutils.h | 39 + .../test/optimizer/graph_transform_test.cc | 5 +- onnxruntime/test/perftest/README.md | 2 +- onnxruntime/test/perftest/TFModelInfo.cc | 2 +- .../test/perftest/command_args_parser.cc | 41 +- onnxruntime/test/perftest/ort_test_session.cc | 10 +- .../test/perftest/performance_runner.h | 2 +- .../test/perftest/test_configuration.h | 4 +- .../cpu/activation/activation_op_test.cc | 5 +- .../providers/cpu/generator/random_test.cc | 51 +- .../test/providers/cpu/math/clip_test.cc | 40 +- .../cpu/math/element_wise_ops_test.cc | 8 + .../providers/cpu/math/matmul_integer_test.cc | 63 +- .../test/providers/cpu/math/softmax_test.cc | 8 +- .../providers/cpu/ml/label_encoder_test.cc | 126 + .../cpu/nn/conv_transpose_op_test.cc | 2 + .../test/providers/cpu/nn/shrink_test.cc | 12 +- .../non_max_suppression_test.cc | 37 +- .../cpu/reduction/reduction_ops_test.cc | 35 + .../providers/cpu/rnn/deep_cpu_gru_op_test.cc | 33 + .../cpu/rnn/deep_cpu_lstm_op_test.cc | 40 + .../test/providers/cpu/rnn/rnn_op_test.cc | 54 +- .../tensor/dynamic_quantize_linear_test.cc | 51 + .../providers/cpu/tensor/nonzero_op_test.cc | 19 + .../providers/cpu/tensor/onehot_op_test.cc | 28 + .../cpu/tensor/quantize_linear_test.cc | 12 +- .../providers/cpu/tensor/resize_op_test.cc | 49 +- .../providers/cpu/tensor/upsample_op_test.cc | 30 +- onnxruntime/test/providers/memcpy_test.cc | 8 +- .../test/providers/provider_test_utils.cc | 94 +- .../test/providers/provider_test_utils.h | 55 + .../providers/tensorrt/tensorrt_basic_test.cc | 6 +- .../test/python/onnx_backend_test_series.py | 13 +- .../test/python/onnxruntime_test_python.py | 29 +- .../python/onnxruntime_test_python_nuphar.py | 111 + .../test/server/unit_tests/converter_tests.cc | 2 +- onnxruntime/test/shared_lib/test_allocator.cc | 2 +- onnxruntime/test/shared_lib/test_inference.cc | 4 +- .../test/shared_lib/test_session_options.cc | 11 +- onnxruntime/test/testdata/CNTK/gen.py | 106 +- .../test_model_with_fullonnxdomain.onnx | 18 + onnxruntime/test/util/default_providers.cc | 11 +- .../test/util/include/default_providers.h | 2 +- requirements-dev.txt | 2 + .../c_cxx/fns_candy_style_transfer/README.md | 19 +- samples/c_cxx/imagenet/main.cc | 5 +- tools/ci_build/build.py | 64 +- tools/ci_build/gen_def.py | 12 +- .../azure-pipelines-py-packaging.yml | 4 +- .../c-api-packaging-pipelines.yml | 2 +- .../azure-pipelines/linux-ci-pipeline.yml | 2 +- .../linux-gpu-tensorrt-ci-pipeline.yml | 3 +- .../azure-pipelines/mac-ci-pipeline.yml | 2 +- .../azure-pipelines/nuget/templates/cpu.yml | 10 + 
.../azure-pipelines/nuget/templates/gpu.yml | 4 +- .../nuget/templates/test_win.yml | 2 +- .../azure-pipelines/templates/esrp_dll.yml | 2 +- .../azure-pipelines/templates/esrp_nuget.yml | 2 +- .../azure-pipelines/templates/win-ci.yml | 45 +- .../azure-pipelines/templates/win-x86-ci.yml | 40 +- .../windows-build-tools-setup-steps.yml | 34 +- .../azure-pipelines/win-ci-pipeline.yml | 2 +- .../azure-pipelines/win-gpu-ci-pipeline.yml | 4 +- .../win-gpu-tensorrt-ci-pipeline.yml | 67 +- .../win-ngraph-ci-pipeline.yml | 4 +- .../github/linux/copy_strip_binary.sh | 2 + .../linux/docker/Dockerfile.ubuntu_tensorrt | 8 +- .../linux/docker/scripts/install_onnx.sh | 2 +- .../github/windows/setup_env_cuda.bat | 2 +- tools/python/gen_opkernel_doc.py | 152 + 539 files changed, 51564 insertions(+), 4146 deletions(-) create mode 100644 cmake/onnxruntime_automl_featurizers.cmake create mode 100644 cmake/onnxruntime_nuphar_extern.cmake delete mode 100644 cmake/patches/ngraph/ngraph_fix_install_error.patch delete mode 100644 cmake/patches/ngraph/ngraph_fix_library_path.patch create mode 100644 cmake/patches/ngraph/ngraph_fix_mkldnn_missing_symbol.patch rename csharp/{src/Microsoft.ML.OnnxRuntime => }/OnnxRuntime.snk (100%) create mode 100644 csharp/src/Microsoft.ML.OnnxRuntime/RunOptions.cs create mode 100644 csharp/src/Microsoft.ML.OnnxRuntime/Tensors/ArrayTensorExtensions.cs create mode 100644 csharp/src/Microsoft.ML.OnnxRuntime/Tensors/ArrayUtilities.cs create mode 100644 csharp/src/Microsoft.ML.OnnxRuntime/Tensors/DenseTensor.cs create mode 100644 csharp/src/Microsoft.ML.OnnxRuntime/Tensors/Tensor.cs create mode 100644 csharp/test/Microsoft.ML.OnnxRuntime.Tests/Tensors/NativeMemory.cs create mode 100644 csharp/test/Microsoft.ML.OnnxRuntime.Tests/Tensors/TensorArithmetic.cs create mode 100644 csharp/test/Microsoft.ML.OnnxRuntime.Tests/Tensors/TensorArithmetic.tt create mode 100644 csharp/test/Microsoft.ML.OnnxRuntime.Tests/Tensors/TensorExtensions.cs create mode 100644 csharp/test/Microsoft.ML.OnnxRuntime.Tests/Tensors/TensorOperations.cs create mode 100644 csharp/test/Microsoft.ML.OnnxRuntime.Tests/Tensors/TensorOperations.tt create mode 100644 csharp/test/Microsoft.ML.OnnxRuntime.Tests/Tensors/TensorTemplate.ttinclude create mode 100644 csharp/test/Microsoft.ML.OnnxRuntime.Tests/Tensors/TensorTests.cs create mode 100644 csharp/test/Microsoft.ML.OnnxRuntime.Tests/Tensors/TensorTestsBase.cs rename dockerfiles/{ => scripts}/install_common_deps.sh (81%) create mode 100644 docs/OperatorKernels.md create mode 100644 docs/execution_providers/Nuphar-ExecutionProvider.md delete mode 100644 include/onnxruntime/core/common/callback.h create mode 100644 include/onnxruntime/core/providers/nuphar/nuphar_provider_factory.h create mode 100644 onnxruntime/automl_ops/automl_featurizers.h create mode 100644 onnxruntime/automl_ops/automl_types.cc create mode 100644 onnxruntime/automl_ops/automl_types.h create mode 100644 onnxruntime/automl_ops/cpu/datetime_transformer.cc create mode 100644 onnxruntime/automl_ops/cpu_automl_kernels.cc create mode 100644 onnxruntime/automl_ops/cpu_automl_kernels.h create mode 100644 onnxruntime/contrib_ops/cpu/matmul_integer16.cc create mode 100644 onnxruntime/contrib_ops/cpu/matmul_integer16.h create mode 100644 onnxruntime/core/automl/featurizers/src/FeaturizerPrep/Featurizer.h create mode 100644 onnxruntime/core/automl/featurizers/src/FeaturizerPrep/Featurizers/DateTimeFeaturizer.cpp create mode 100644 onnxruntime/core/automl/featurizers/src/FeaturizerPrep/Featurizers/DateTimeFeaturizer.h 
create mode 100644 onnxruntime/core/automl/featurizers/src/FeaturizerPrep/Featurizers/SampleAdd.cpp create mode 100644 onnxruntime/core/automl/featurizers/src/FeaturizerPrep/Featurizers/SampleAdd.h create mode 100644 onnxruntime/core/automl/featurizers/src/FeaturizerPrep/Featurizers/UnitTests/CMakeLists.txt create mode 100644 onnxruntime/core/automl/featurizers/src/FeaturizerPrep/Featurizers/UnitTests/DateTimeFeaturizer_UnitTests.cpp create mode 100644 onnxruntime/core/automl/featurizers/src/FeaturizerPrep/Featurizers/UnitTests/SampleAdd_UnitTest.cpp create mode 100644 onnxruntime/core/automl/featurizers/src/FeaturizerPrep/Featurizers/UnitTests/code_coverage.yaml create mode 100644 onnxruntime/core/automl/featurizers/src/FeaturizerPrep/Traits.h create mode 100644 onnxruntime/core/automl/featurizers/src/FeaturizerPrep/UnitTests/CMakeLists.txt create mode 100644 onnxruntime/core/automl/featurizers/src/FeaturizerPrep/UnitTests/Featurizer_UnitTest.cpp create mode 100644 onnxruntime/core/automl/featurizers/src/FeaturizerPrep/UnitTests/Traits_UnitTests.cpp create mode 100644 onnxruntime/core/automl/featurizers/src/FeaturizerPrep/UnitTests/test_main.cpp delete mode 100644 onnxruntime/core/common/task_thread_pool.h create mode 100644 onnxruntime/core/framework/callback.h create mode 100644 onnxruntime/core/graph/automl_ops/automl_defs.cc create mode 100644 onnxruntime/core/graph/automl_ops/automl_defs.h rename onnxruntime/core/mlas/lib/aarch64/{sgemma.s => SgemmKernelNeon.S} (95%) create mode 100644 onnxruntime/core/mlas/lib/amd64/AssembleAvx512Vnni.inc create mode 100644 onnxruntime/core/mlas/lib/amd64/QgemmU8U8KernelAvx2.asm create mode 100644 onnxruntime/core/mlas/lib/amd64/QgemmU8U8KernelAvx512BW.asm create mode 100644 onnxruntime/core/mlas/lib/amd64/QgemmU8U8KernelAvx512Common.inc create mode 100644 onnxruntime/core/mlas/lib/amd64/QgemmU8U8KernelAvx512Vnni.asm rename onnxruntime/core/mlas/lib/arm64/{sgemma.asm => SgemmKernelNeon.asm} (91%) create mode 100644 onnxruntime/core/mlas/lib/qgemm.cpp create mode 100644 onnxruntime/core/mlas/lib/x86_64/AssembleAvx512Vnni.h create mode 100644 onnxruntime/core/mlas/lib/x86_64/QgemmU8U8KernelAvx2.S create mode 100644 onnxruntime/core/mlas/lib/x86_64/QgemmU8U8KernelAvx512BW.S create mode 100644 onnxruntime/core/mlas/lib/x86_64/QgemmU8U8KernelAvx512Common.h create mode 100644 onnxruntime/core/mlas/lib/x86_64/QgemmU8U8KernelAvx512Vnni.S create mode 100644 onnxruntime/core/providers/cpu/tensor/dynamicquantizelinear.cc create mode 100644 onnxruntime/core/providers/cpu/tensor/dynamicquantizelinear.h create mode 100644 onnxruntime/core/providers/nuphar/common/analysis/analysis.h create mode 100644 onnxruntime/core/providers/nuphar/common/analysis/graph_stats.h create mode 100644 onnxruntime/core/providers/nuphar/common/analysis/output_alias_analysis.cc create mode 100644 onnxruntime/core/providers/nuphar/common/analysis/output_alias_analysis.h create mode 100644 onnxruntime/core/providers/nuphar/common/analysis/shape_expr.h create mode 100644 onnxruntime/core/providers/nuphar/common/analysis/subgraph_codegen_stats.cc create mode 100644 onnxruntime/core/providers/nuphar/common/analysis/subgraph_codegen_stats.h create mode 100644 onnxruntime/core/providers/nuphar/common/analysis/subgraph_partition_stats.cc create mode 100644 onnxruntime/core/providers/nuphar/common/analysis/subgraph_partition_stats.h create mode 100644 onnxruntime/core/providers/nuphar/common/analysis/use_count_analysis.cc create mode 100644 
onnxruntime/core/providers/nuphar/common/analysis/use_count_analysis.h create mode 100644 onnxruntime/core/providers/nuphar/common/nuphar_settings.cc create mode 100644 onnxruntime/core/providers/nuphar/common/nuphar_settings.h create mode 100644 onnxruntime/core/providers/nuphar/common/nuphar_subgraph.h create mode 100644 onnxruntime/core/providers/nuphar/common/nuphar_tvm_utils.cc create mode 100644 onnxruntime/core/providers/nuphar/common/nuphar_tvm_utils.h create mode 100644 onnxruntime/core/providers/nuphar/common/utils.cc create mode 100644 onnxruntime/core/providers/nuphar/common/utils.h create mode 100644 onnxruntime/core/providers/nuphar/compiler/codegen_manager.cc create mode 100644 onnxruntime/core/providers/nuphar/compiler/codegen_manager.h create mode 100644 onnxruntime/core/providers/nuphar/compiler/func_info.cc create mode 100644 onnxruntime/core/providers/nuphar/compiler/func_info.h create mode 100644 onnxruntime/core/providers/nuphar/compiler/initializer_info.h create mode 100644 onnxruntime/core/providers/nuphar/compiler/nuphar_codegen_ctx.cc create mode 100644 onnxruntime/core/providers/nuphar/compiler/nuphar_codegen_ctx.h create mode 100644 onnxruntime/core/providers/nuphar/compiler/nuphar_compiler.cc create mode 100644 onnxruntime/core/providers/nuphar/compiler/nuphar_compiler.h create mode 100644 onnxruntime/core/providers/nuphar/compiler/nuphar_handle.h create mode 100644 onnxruntime/core/providers/nuphar/compiler/nuphar_op_ir_builder.cc create mode 100644 onnxruntime/core/providers/nuphar/compiler/nuphar_op_ir_builder.h create mode 100644 onnxruntime/core/providers/nuphar/compiler/nuphar_schedule_builder.cc create mode 100644 onnxruntime/core/providers/nuphar/compiler/nuphar_schedule_builder.h create mode 100644 onnxruntime/core/providers/nuphar/compiler/traverse_shape_infer.cc create mode 100644 onnxruntime/core/providers/nuphar/compiler/traverse_shape_infer.h create mode 100644 onnxruntime/core/providers/nuphar/compiler/x86/op_ir_creator/all_ops.h create mode 100644 onnxruntime/core/providers/nuphar/compiler/x86/op_ir_creator/math/gemm.cc create mode 100644 onnxruntime/core/providers/nuphar/compiler/x86/op_ir_creator/math/logsoftmax.cc create mode 100644 onnxruntime/core/providers/nuphar/compiler/x86/op_ir_creator/math/matmul.cc create mode 100644 onnxruntime/core/providers/nuphar/compiler/x86/op_ir_creator/math/quantize/matmul_integer.cc create mode 100644 onnxruntime/core/providers/nuphar/compiler/x86/op_ir_creator/math/reduce_ops.cc create mode 100644 onnxruntime/core/providers/nuphar/compiler/x86/op_ir_creator/math/softmax.cc create mode 100644 onnxruntime/core/providers/nuphar/compiler/x86/op_ir_creator/math/unary_ops.cc create mode 100644 onnxruntime/core/providers/nuphar/compiler/x86/op_ir_creator/tensor/slice.cc create mode 100644 onnxruntime/core/providers/nuphar/compiler/x86/op_ir_creator/tensor/tile.cc create mode 100644 onnxruntime/core/providers/nuphar/compiler/x86/scheduler/analysis_schedule.cc create mode 100644 onnxruntime/core/providers/nuphar/compiler/x86/scheduler/nuphar_scheduler.cc create mode 100644 onnxruntime/core/providers/nuphar/compiler/x86/scheduler/nuphar_scheduler.h create mode 100644 onnxruntime/core/providers/nuphar/compiler/x86/scheduler/ort_type_schedule.cc create mode 100644 onnxruntime/core/providers/nuphar/compiler/x86/scheduler/partial_schedule.cc create mode 100644 onnxruntime/core/providers/nuphar/compiler/x86/scheduler/tensorize/intrin_gemv_16bit.cc create mode 100644 
onnxruntime/core/providers/nuphar/compiler/x86/scheduler/tensorize/intrin_gemv_16bit.h create mode 100644 onnxruntime/core/providers/nuphar/compiler/x86/scheduler/tensorize/intrin_gemv_8bit.cc create mode 100644 onnxruntime/core/providers/nuphar/compiler/x86/scheduler/tensorize/intrin_gemv_8bit.h create mode 100644 onnxruntime/core/providers/nuphar/compiler/x86/scheduler/tensorize/intrin_gemv_ll_extern.cc create mode 100644 onnxruntime/core/providers/nuphar/compiler/x86/scheduler/tensorize/intrin_gemv_ll_extern.h create mode 100644 onnxruntime/core/providers/nuphar/compiler/x86/scheduler/tensorize/intrin_gemv_ll_ir.cc create mode 100644 onnxruntime/core/providers/nuphar/compiler/x86/scheduler/tensorize/intrin_gemv_ll_ir.h create mode 100644 onnxruntime/core/providers/nuphar/compiler/x86/scheduler/tensorize/ll/gemv_impl.cpp create mode 100644 onnxruntime/core/providers/nuphar/compiler/x86/scheduler/tensorize/ll/gemv_impl.h create mode 100644 onnxruntime/core/providers/nuphar/compiler/x86/scheduler/tensorize/tensorize_base.h create mode 100644 onnxruntime/core/providers/nuphar/compiler/x86/scheduler/tensorize/tensorize_utilities.cc create mode 100644 onnxruntime/core/providers/nuphar/compiler/x86/scheduler/tensorize/tensorize_utilities.h create mode 100644 onnxruntime/core/providers/nuphar/compiler/x86/scheduler/tensorize_schedule.cc create mode 100644 onnxruntime/core/providers/nuphar/compiler/x86/scheduler/tvm_rule_schedule.cc create mode 100644 onnxruntime/core/providers/nuphar/compiler/x86/x86_target_info.cc create mode 100644 onnxruntime/core/providers/nuphar/compiler/x86/x86_target_info.h create mode 100644 onnxruntime/core/providers/nuphar/extern/igemv_avx2.cc create mode 100644 onnxruntime/core/providers/nuphar/extern/igemv_avx2.h create mode 100644 onnxruntime/core/providers/nuphar/extern/igemv_mkl.cc create mode 100644 onnxruntime/core/providers/nuphar/extern/igemv_mkl.h create mode 100644 onnxruntime/core/providers/nuphar/kernel.cc create mode 100644 onnxruntime/core/providers/nuphar/kernel.h create mode 100644 onnxruntime/core/providers/nuphar/mti_x86/math/halide_ops.cc create mode 100644 onnxruntime/core/providers/nuphar/mti_x86/math/halide_ops.h create mode 100644 onnxruntime/core/providers/nuphar/mti_x86/math/logsoftmax.cc create mode 100644 onnxruntime/core/providers/nuphar/mti_x86/math/logsoftmax.h create mode 100644 onnxruntime/core/providers/nuphar/mti_x86/math/matmul_ops.cc create mode 100644 onnxruntime/core/providers/nuphar/mti_x86/math/matmul_ops.h create mode 100644 onnxruntime/core/providers/nuphar/mti_x86/math/reduce_ops.cc create mode 100644 onnxruntime/core/providers/nuphar/mti_x86/math/reduce_ops.h create mode 100644 onnxruntime/core/providers/nuphar/mti_x86/math/softmax.cc create mode 100644 onnxruntime/core/providers/nuphar/mti_x86/math/softmax.h create mode 100644 onnxruntime/core/providers/nuphar/mti_x86/math/softmax_internal.cc create mode 100644 onnxruntime/core/providers/nuphar/mti_x86/math/softmax_internal.h create mode 100644 onnxruntime/core/providers/nuphar/mti_x86/math/unary_ops.cc create mode 100644 onnxruntime/core/providers/nuphar/mti_x86/math/unary_ops.h create mode 100644 onnxruntime/core/providers/nuphar/mti_x86/quantize/imatmul16_extern.cc create mode 100644 onnxruntime/core/providers/nuphar/mti_x86/quantize/imatmul16_extern.h create mode 100644 onnxruntime/core/providers/nuphar/mti_x86/quantize/imatmul_extern.cc create mode 100644 onnxruntime/core/providers/nuphar/mti_x86/quantize/imatmul_extern.h create mode 100644 
onnxruntime/core/providers/nuphar/nuphar_execution_provider.cc create mode 100644 onnxruntime/core/providers/nuphar/nuphar_execution_provider.h create mode 100644 onnxruntime/core/providers/nuphar/nuphar_provider_factory.cc create mode 100644 onnxruntime/core/providers/nuphar/partition/graph_partitioner.cc create mode 100644 onnxruntime/core/providers/nuphar/partition/graph_partitioner.h create mode 100644 onnxruntime/core/providers/nuphar/partition/partitioner.cc create mode 100644 onnxruntime/core/providers/nuphar/partition/partitioner.h create mode 100644 onnxruntime/core/providers/nuphar/partition/subgraph_partitioner.cc create mode 100644 onnxruntime/core/providers/nuphar/partition/subgraph_partitioner.h create mode 100644 onnxruntime/core/providers/nuphar/runtime/compute_ctx.cc create mode 100644 onnxruntime/core/providers/nuphar/runtime/compute_ctx.h create mode 100644 onnxruntime/core/providers/nuphar/runtime/control_flow/loop_exec_ctx.h create mode 100644 onnxruntime/core/providers/nuphar/runtime/control_flow/scan_exec_ctx.cc create mode 100644 onnxruntime/core/providers/nuphar/runtime/control_flow/scan_exec_ctx.h create mode 100644 onnxruntime/core/providers/nuphar/runtime/exec_block.cc create mode 100644 onnxruntime/core/providers/nuphar/runtime/exec_block.h create mode 100644 onnxruntime/core/providers/nuphar/runtime/handle.h create mode 100644 onnxruntime/core/providers/nuphar/runtime/sequential/basic.cc create mode 100644 onnxruntime/core/providers/nuphar/runtime/sequential/basic.h create mode 100644 onnxruntime/core/providers/nuphar/runtime/sequential/loop.cc create mode 100644 onnxruntime/core/providers/nuphar/runtime/sequential/loop.h create mode 100644 onnxruntime/core/providers/nuphar/runtime/utils.h create mode 100644 onnxruntime/core/providers/nuphar/scripts/README.md create mode 100644 onnxruntime/core/providers/nuphar/scripts/cntk_converter.py create mode 100644 onnxruntime/core/providers/nuphar/scripts/create_shared.cmd create mode 100644 onnxruntime/core/providers/nuphar/scripts/create_shared.sh create mode 100644 onnxruntime/core/providers/nuphar/scripts/model_editor.py create mode 100644 onnxruntime/core/providers/nuphar/scripts/model_quantizer.py create mode 100644 onnxruntime/core/providers/nuphar/scripts/node_factory.py create mode 100644 onnxruntime/core/providers/nuphar/scripts/rnn_benchmark.py create mode 100644 onnxruntime/core/providers/nuphar/scripts/symbolic_shape_infer.py create mode 100644 onnxruntime/core/providers/nuphar/symbols.txt delete mode 100755 onnxruntime/core/providers/tensorrt/tensorrt_allocator.h create mode 100644 onnxruntime/core/util/gemmlowp_common.cc create mode 100644 onnxruntime/core/util/gemmlowp_common.h create mode 100644 onnxruntime/core/util/qmath.cc create mode 100644 onnxruntime/core/util/qmath.h create mode 100644 onnxruntime/test/automl_ops/datetimetransformer_test.cc create mode 100644 onnxruntime/test/contrib_ops/matmul_integer16_test.cc rename onnxruntime/test/{shared_lib => framework}/test_tensor_loader.cc (59%) create mode 100644 onnxruntime/test/onnx/callback.cc create mode 100644 onnxruntime/test/onnx/callback.h create mode 100644 onnxruntime/test/onnx/mem_buffer.h delete mode 100644 onnxruntime/test/onnx/microbenchmark/model_init.cc create mode 100644 onnxruntime/test/onnx/tensorprotoutils.cc create mode 100644 onnxruntime/test/onnx/tensorprotoutils.h create mode 100644 onnxruntime/test/providers/cpu/tensor/dynamic_quantize_linear_test.cc create mode 100644 onnxruntime/test/python/onnxruntime_test_python_nuphar.py 
create mode 100644 onnxruntime/test/testdata/test_model_with_fullonnxdomain.onnx create mode 100644 tools/python/gen_opkernel_doc.py diff --git a/.gitmodules b/.gitmodules index 47dc9124d74dd..6eb38ef853cab 100644 --- a/.gitmodules +++ b/.gitmodules @@ -25,10 +25,6 @@ [submodule "cmake/external/re2"] path = cmake/external/re2 url = https://github.com/google/re2.git -[submodule "cmake/external/onnx-tensorrt"] - path = cmake/external/onnx-tensorrt - url = https://github.com/onnx/onnx-tensorrt.git - branch = v5.0 [submodule "cmake/external/eigen"] path = cmake/external/eigen url = https://github.com/eigenteam/eigen-git-mirror.git @@ -41,3 +37,7 @@ [submodule "cmake/external/spdlog"] path = cmake/external/spdlog url = https://github.com/gabime/spdlog.git +[submodule "cmake/external/onnx-tensorrt"] + path = cmake/external/onnx-tensorrt + url = https://github.com/onnx/onnx-tensorrt.git + branch = 5.1 diff --git a/BUILD.md b/BUILD.md index f4d1650ee03ad..f1bd922639d86 100644 --- a/BUILD.md +++ b/BUILD.md @@ -1,37 +1,9 @@ -# Build ONNX Runtime -Dockerfiles are available [here](https://github.com/microsoft/onnxruntime/tree/master/tools/ci_build/github/linux/docker) to help you get started. +# Building ONNX Runtime - Getting Started +*Dockerfiles are available [here](https://github.com/microsoft/onnxruntime/tree/master/tools/ci_build/github/linux/docker) to help you get started.* -## Supported architectures +*Pre-built packages are available at the locations indicated [here](https://github.com/microsoft/onnxruntime#official-builds).* -| | x86_32 | x86_64 | ARM32v7 | ARM64 | -|-----------|:------------:|:------------:|:------------:|:------------:| -|Windows | YES | YES | YES | YES | -|Linux | YES | YES | YES | YES | -|Mac OS X | NO | YES | NO | NO | - -## Supported dev environments - -| OS | Supports CPU | Supports GPU| Notes | -|-------------|:------------:|:------------:|------------------------------------| -|Windows 10 | YES | YES | VS2019 through the latest VS2015 are supported | -|Windows 10
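For orientation, the Windows steps above collapse into the following minimal sketch; it assumes Git, CMake 3.13+ and Visual Studio 2017 are already installed and on PATH, and it skips the optional protobuf and onnx source installs:
```
REM run from a Developer Command Prompt, in the folder where you want the source tree
git clone --recursive https://github.com/Microsoft/onnxruntime
cd onnxruntime
REM builds the shared library with the default Visual Studio 2017 generator
.\build.bat --config RelWithDebInfo --build_shared_lib --parallel
```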
Subsystem for Linux | YES | NO | | -|Ubuntu 16.x | YES | YES | Also supported on ARM32v7 (experimental) | - -* Red Hat Enterprise Linux and CentOS are not supported. -* Other version of Ubuntu might work but we don't support them officially. -* GCC 4.x and below are not supported. - -OS/Compiler Matrix: - -| OS/Compiler | Supports VC | Supports GCC | -|-------------|:------------:|:----------------:| -|Windows 10 | YES | Not tested | -|Linux | NO | YES(gcc>=5.0) | - -ONNX Runtime python binding only supports Python 3.5, 3.6 and 3.7. - -## Getting Started -You may either get a prebuilt onnxruntime from nuget.org, or do it yourself using the following steps: +## To build the baseline CPU version of ONNX Runtime from source: 1. Checkout the source tree: ``` git clone --recursive https://github.com/Microsoft/onnxruntime @@ -39,7 +11,8 @@ You may either get a prebuilt onnxruntime from nuget.org, or do it yourself usin ``` 2. Install cmake-3.13 or better from https://cmake.org/download/. -On Windows: +**On Windows:** + 3. (optional) Install protobuf 3.6.1 from source code (cmake/external/protobuf). CMake flag protobuf\_BUILD\_SHARED\_LIBS must be turned OFF. After the installation, you should have the 'protoc' executable in your PATH. 4. (optional) Install onnx from source code (cmake/external/onnx) ``` @@ -49,7 +22,10 @@ On Windows: ``` 5. Run `build.bat --config RelWithDebInfo --build_shared_lib --parallel`. -On Linux: +*Note: The default Windows CMake Generator is Visual Studio 2017, but you can also use the newer Visual Studio 2019 by passing `--cmake_generator "Visual Studio 16 2019"` to build.bat.* + +**On Linux:** + 3. (optional) Install protobuf 3.6.1 from source code (cmake/external/protobuf). CMake flag protobuf\_BUILD\_SHARED\_LIBS must be turned ON. After the installation, you should have the 'protoc' executable in your PATH. It is recommended to run `ldconfig` to make sure protobuf libraries are found. 4. If you installed your protobuf in a non standard location it would be helpful to set the following env var:`export CMAKE_ARGS="-DONNX_CUSTOM_PROTOC_EXECUTABLE=full path to protoc"` so ONNX build can find it. Also run `ldconfig ` so the linker can find protobuf libraries. 5. (optional) Install onnx from source code (cmake/external/onnx) @@ -62,46 +38,120 @@ On Linux: The build script runs all unit tests by default (for native builds and skips tests by default for cross-compiled builds). +--- + +# Supported architectures and build environments + +## Architectures + +| | x86_32 | x86_64 | ARM32v7 | ARM64 | +|-----------|:------------:|:------------:|:------------:|:------------:| +|Windows | YES | YES | YES | YES | +|Linux | YES | YES | YES | YES | +|Mac OS X | NO | YES | NO | NO | + +## Environments + +| OS | Supports CPU | Supports GPU| Notes | +|-------------|:------------:|:------------:|------------------------------------| +|Windows 10 | YES | YES | VS2019 through the latest VS2015 are supported | +|Windows 10
Subsystem for Linux | YES | NO | | +|Ubuntu 16.x | YES | YES | Also supported on ARM32v7 (experimental) | + +* Red Hat Enterprise Linux and CentOS are not supported. +* Other version of Ubuntu might work but we don't support them officially. +* GCC 4.x and below are not supported. + +### OS/Compiler Matrix: + +| OS/Compiler | Supports VC | Supports GCC | +|-------------|:------------:|:----------------:| +|Windows 10 | YES | Not tested | +|Linux | NO | YES(gcc>=5.0) | + +ONNX Runtime Python bindings support Python 3.5, 3.6 and 3.7. + +--- + +# Additional Build Instructions The complete list of build options can be found by running `./build.sh (or ./build.bat) --help` -## Build x86 - - For Windows, just add --x86 argument when launching build.bat - - For Linux, it must be built out of a x86 os, --x86 argument also needs be specified to build.sh +* [Docker on Linux](#Docker-on-Linux) +* [ONNX Runtime Server (Linux)](#Build-ONNX-Runtime-Server-on-Linux) -## Build ONNX Runtime Server on Linux +**Execution Providers** +* [NVIDIA CUDA](#CUDA) +* [NVIDIA TensorRT](#TensorRT) +* [Intel MKL-DNN/MKL-ML](#MKLDNN-and-MKLML) +* [Intel nGraph](#nGraph) +* [Intel OpenVINO](#openvino) +* [Android NNAPI](#Android) +* [Nuphar](#Nuphar) + +**Options** +* [OpenMP](#OpenMP) +* [OpenBLAS](#OpenBLAS) + +**Architectures** +* [x86](#x86) +* [ARM](#ARM) + +--- +## Docker on Linux +Install Docker: `https://docs.docker.com/install/` + +**CPU** +``` +cd tools/ci_build/github/linux/docker +docker build -t onnxruntime_dev --build-arg OS_VERSION=16.04 -f Dockerfile.ubuntu . +docker run --rm -it onnxruntime_dev /bin/bash +``` + +**GPU** +If you need GPU support, please also install: +1. nvidia driver. Before doing this please add `nomodeset rd.driver.blacklist=nouveau` to your linux [kernel boot parameters](https://www.kernel.org/doc/html/v4.17/admin-guide/kernel-parameters.html). +2. nvidia-docker2: [Install doc](`https://github.com/NVIDIA/nvidia-docker/wiki/Installation-(version-2.0)`) + +To test if your nvidia-docker works: +``` +docker run --runtime=nvidia --rm nvidia/cuda nvidia-smi +``` + +Then build a docker image. We provided a sample for use: +``` +cd tools/ci_build/github/linux/docker +docker build -t cuda_dev -f Dockerfile.ubuntu_gpu . +``` + +Then run it +``` +./tools/ci_build/github/linux/run_dockerbuild.sh +``` + +--- +## Build ONNX Runtime Server on Linux +Read more about ONNX Runtime Server [here](https://github.com/microsoft/onnxruntime/blob/master/docs/ONNX_Runtime_Server_Usage.md) 1. ONNX Runtime server (and only the server) requires you to have Go installed to build, due to building BoringSSL. See https://golang.org/doc/install for installation instructions. 2. In the ONNX Runtime root folder, run `./build.sh --config RelWithDebInfo --build_server --use_openmp --parallel` 3. ONNX Runtime Server supports sending log to [rsyslog](https://www.rsyslog.com/) daemon. To enable it, please build with an additional parameter: `--cmake_extra_defines onnxruntime_USE_SYSLOG=1`. 
The build command will look like this: `./build.sh --config RelWithDebInfo --build_server --use_openmp --parallel --cmake_extra_defines onnxruntime_USE_SYSLOG=1` +--- -## Build/Test Flavors for CI - -### CI Build Environments - -| Build Job Name | Environment | Dependency | Test Coverage | Scripts | -|--------------------|---------------------|---------------------------------|--------------------------|------------------------------------------| -| Linux_CI_Dev | Ubuntu 16.04 | python=3.5 | Unit tests; ONNXModelZoo | [script](tools/ci_build/github/linux/run_build.sh) | -| Linux_CI_GPU_Dev | Ubuntu 16.04 | python=3.5; nvidia-docker | Unit tests; ONNXModelZoo | [script](tools/ci_build/github/linux/run_build.sh) | -| Windows_CI_Dev | Windows Server 2016 | python=3.5 | Unit tests; ONNXModelZoo | [script](build.bat) | -| Windows_CI_GPU_Dev | Windows Server 2016 | cuda=9.1; cudnn=7.1; python=3.5 | Unit tests; ONNXModelZoo | [script](build.bat) | - -## Additional Build Flavors -The complete list of build flavors can be seen by running `./build.sh --help` or `./build.bat --help`. Here are some common flavors. +## Execution Providers -### Windows CMake Generator -The default generator on Windows is Visual Studio 2017, but you can also use the newer Visual Studio 2019 by passing `--cmake_generator "Visual Studio 16 2019"` to build.bat. +### CUDA +For Linux, please use [this Dockerfile](https://github.com/microsoft/onnxruntime/blob/master/tools/ci_build/github/linux/docker/Dockerfile.ubuntu_gpu) and refer to instructions above for [building with Docker on Linux](#Docker-on-Linux) -### Windows CUDA Build -ONNX Runtime supports CUDA builds. You will need to download and install [CUDA](https://developer.nvidia.com/cuda-toolkit) and [CUDNN](https://developer.nvidia.com/cudnn). +ONNX Runtime supports CUDA builds. You will need to download and install [CUDA](https://developer.nvidia.com/cuda-toolkit) and [cuDNN](https://developer.nvidia.com/cudnn). -ONNX Runtime is built and tested with CUDA 9.1 and CUDNN 7.1 using the Visual Studio 2017 14.11 toolset (i.e. Visual Studio 2017 v15.3). -CUDA versions from 9.1 up to 10.0, and CUDNN versions from 7.1 up to 7.4 should also work with Visual Studio 2017. +ONNX Runtime is built and tested with CUDA 10.0 and cuDNN 7.3 using the Visual Studio 2017 14.11 toolset (i.e. Visual Studio 2017 v15.3). +CUDA versions from 9.1 up to 10.1, and cuDNN versions from 7.1 up to 7.4 should also work with Visual Studio 2017. - The path to the CUDA installation must be provided via the CUDA_PATH environment variable, or the `--cuda_home parameter`. - - The path to the CUDNN installation (include the `cuda` folder in the path) must be provided via the CUDNN_PATH environment variable, or `--cudnn_home parameter`. The CUDNN path should contain `bin`, `include` and `lib` directories. - - The path to the CUDNN bin directory must be added to the PATH environment variable so that cudnn64_7.dll is found. + - The path to the cuDNN installation (include the `cuda` folder in the path) must be provided via the cuDNN_PATH environment variable, or `--cudnn_home parameter`. The cuDNN path should contain `bin`, `include` and `lib` directories. + - The path to the cuDNN bin directory must be added to the PATH environment variable so that cudnn64_7.dll is found. 
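As a hedged illustration of the requirements above, the environment can be prepared on Windows roughly as follows; the install paths are placeholders for wherever CUDA and cuDNN actually live on your machine:
```
REM example locations only - point these at your actual CUDA and cuDNN installs
set CUDA_PATH=C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v10.0
set CUDNN_PATH=C:\local\cudnn\cuda
REM make cudnn64_7.dll discoverable at runtime
set PATH=%CUDNN_PATH%\bin;%PATH%
```
With these variables set, the commands below can locate CUDA and cuDNN without the explicit `--cuda_home`/`--cudnn_home` parameters.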
You can build with: @@ -110,7 +160,7 @@ You can build with: ./build.bat --use_cuda --cudnn_home --cuda_home (Windows) ``` -Depending on compatibility between the CUDA, CUDNN, and Visual Studio 2017 versions you are using, you may need to explicitly install an earlier version of the MSVC toolset. +Depending on compatibility between the CUDA, cuDNN, and Visual Studio 2017 versions you are using, you may need to explicitly install an earlier version of the MSVC toolset. - CUDA 10.0 is known to work with toolsets from 14.11 up to 14.16 (Visual Studio 2017 15.9), and should continue to work with future Visual Studio versions - https://devblogs.microsoft.com/cppblog/cuda-10-is-now-available-with-support-for-the-latest-visual-studio-2017-versions/ - CUDA 9.2 is known to work with the 14.11 MSVC toolset (Visual Studio 15.3 and 15.4) @@ -132,30 +182,38 @@ _Side note: If you have multiple versions of CUDA installed on a Windows machine e.g. C:\Program Files (x86)\Microsoft Visual Studio\2017\Enterprise\Common7\IDE\VC\VCTargets\BuildCustomizations\. If you want to build with an earlier version, you must temporarily remove the 'CUDA x.y.*' files for later versions from this directory._ -### MKL-DNN/MKLML -To build ONNX Runtime with MKL-DNN support, build it with `./build.sh --use_mkldnn` -To build ONNX Runtime using MKL-DNN built with dependency on MKL small libraries, build it with `./build.sh --use_mkldnn --use_mklml` - -### nGraph -ONNX runtime with nGraph as an execution provider (released as preview) can be built on Linux as follows : `./build.sh --use_ngraph`. Similarly, on Windows use `.\build.bat --use_ngraph`. +--- ### TensorRT -ONNX Runtime supports the TensorRT execution provider (released as preview). You will need to download and install [CUDA](https://developer.nvidia.com/cuda-toolkit), [CUDNN](https://developer.nvidia.com/cudnn) and [TensorRT](https://developer.nvidia.com/nvidia-tensorrt-download). +ONNX Runtime supports the TensorRT execution provider (released as preview). You will need to download and install [CUDA](https://developer.nvidia.com/cuda-toolkit), [cuDNN](https://developer.nvidia.com/cudnn) and [TensorRT](https://developer.nvidia.com/nvidia-tensorrt-download). -The TensorRT execution provider for ONNX Runtime is built and tested with CUDA 9.0/CUDA 10.0, CUDNN 7.1 and TensorRT 5.0.2.6. +The TensorRT execution provider for ONNX Runtime is built and tested with CUDA 9.0/CUDA 10.0, cuDNN 7.1 and TensorRT 5.0.2.6. - The path to the CUDA installation must be provided via the CUDA_PATH environment variable, or the `--cuda_home parameter`. The CUDA path should contain `bin`, `include` and `lib` directories. - The path to the CUDA `bin` directory must be added to the PATH environment variable so that `nvcc` is found. - - The path to the CUDNN installation (path to folder that contains libcudnn.so) must be provided via the CUDNN_PATH environment variable, or `--cudnn_home parameter`. + - The path to the cuDNN installation (path to folder that contains libcudnn.so) must be provided via the cuDNN_PATH environment variable, or `--cudnn_home parameter`. - The path to TensorRT installation must be provided via the `--tensorrt_home parameter`. 
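To make these requirements concrete, here is a minimal sketch of preparing a Linux environment for the TensorRT build; every path is an assumption to be replaced with your own install locations, and the `./build.sh` invocation itself is shown in the next code block:
```bash
# placeholder paths - substitute the locations of your own installs
export CUDA_PATH=/usr/local/cuda-10.0           # should contain bin/, include/ and lib/
export CUDNN_PATH=/usr/lib/x86_64-linux-gnu     # folder that contains libcudnn.so
export PATH=$CUDA_PATH/bin:$PATH                # so that nvcc is found
# the TensorRT install location is passed to build.sh via --tensorrt_home
```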
You can build from source on Linux by using the following `cmd` from the onnxruntime directory: ``` -./build.sh --cudnn_home --cuda_home --use_tensorrt --tensorrt_home (Linux) - ``` -### OpenVINO Build + +--- + +### MKLDNN and MKLML +To build ONNX Runtime with MKL-DNN support, build it with `./build.sh --use_mkldnn` +To build ONNX Runtime using MKL-DNN built with dependency on MKL small libraries, build it with `./build.sh --use_mkldnn --use_mklml` + +--- + +### nGraph +ONNX Runtime with nGraph as an execution provider (released as preview) can be built on Linux as follows: `./build.sh --use_ngraph`. Similarly, on Windows use `.\build.bat --use_ngraph` + +--- + +### OpenVINO ONNX Runtime supports OpenVINO Execution Provider to enable deep learning inference using Intel® OpenVINO™ Toolkit. This execution provider supports several Intel hardware device types - CPU, integrated GPU, Intel® Movidius™ VPUs and Intel® Vision accelerator Design with 8 Intel Movidius™ MyriadX VPUs. @@ -194,58 +252,97 @@ The OpenVINO Execution Provider can be built using the following commands: | VAD-M_FP16 | Intel® Vision Accelerator Design based on 8 Movidius™ MyriadX VPUs | For more information on OpenVINO Execution Provider's ONNX Layer support, Topology support, and Intel hardware enabled, please refer to the document OpenVINO-ExecutionProvider.md in $onnxruntime_root/docs/execution_providers + +--- -### OpenBLAS -#### Windows -Instructions how to build OpenBLAS for windows can be found here https://github.com/xianyi/OpenBLAS/wiki/How-to-use-OpenBLAS-in-Microsoft-Visual-Studio#build-openblas-for-universal-windows-platform. +### Android -Once you have the OpenBLAS binaries, build ONNX Runtime with `./build.bat --use_openblas` #### Cross compiling on Linux -#### Linux -For Linux (e.g. Ubuntu 16.04), install libopenblas-dev package -`sudo apt-get install libopenblas-dev` and build with `./build.sh --use_openblas` +1. Get Android NDK from https://developer.android.com/ndk/downloads. Please unzip it after downloading. -### OpenMP -``` -./build.sh --use_openmp (for Linux) -./build.bat --use_openmp (for Windows) -``` +2. Get a pre-compiled protoc: -### Build with Docker on Linux -Install Docker: `https://docs.docker.com/install/` + You may get it from https://github.com/protocolbuffers/protobuf/releases/download/v3.6.1/protoc-3.6.1-linux-x86_64.zip. Please unzip it after downloading. + +3. Denote the unzip destination in step 1 as $ANDROID_NDK, append `-DCMAKE_TOOLCHAIN_FILE=$ANDROID_NDK/build/cmake/android.toolchain.cmake -DANDROID_ABI=arm64-v8a -DONNX_CUSTOM_PROTOC_EXECUTABLE=path/to/protoc` to your cmake args, run cmake and make to build it. + +Note: For 32-bit devices, replace `-DANDROID_ABI=arm64-v8a` with `-DANDROID_ABI=armeabi-v7a`. + +--- -#### CPU ### Nuphar +ONNX Runtime supports the Nuphar execution provider (released as preview). It is an execution provider built on top of [TVM](https://github.com/dmlc/tvm) and [LLVM](https://llvm.org). Currently it targets x64 CPUs. + +The Nuphar execution provider for ONNX Runtime is built and tested with LLVM 6.0.1. Because of TVM's requirement when building with LLVM, you need to build LLVM from source: + +Windows with Visual Studio 2017: (Note: this builds the Release flavor; a Debug build of LLVM is needed to build the Debug flavor of ONNX Runtime) ``` -cd tools/ci_build/github/linux/docker -docker build -t onnxruntime_dev --build-arg OS_VERSION=16.04 -f Dockerfile.ubuntu . 
-docker run --rm -it onnxruntime_dev /bin/bash +REM download llvm source code 6.0.1 and unzip to \llvm\source\path, then install to \llvm\install\path +cd \llvm\source\path +mkdir build +cd build +cmake .. -G "Visual Studio 15 2017 Win64" -DLLVM_TARGETS_TO_BUILD=X86 +msbuild llvm.sln /maxcpucount /p:Configuration=Release /p:Platform=x64 +cmake -DCMAKE_INSTALL_PREFIX=\llvm\install\path -DBUILD_TYPE=Release -P cmake_install.cmake ``` -#### GPU -If you need GPU support, please also install: -1. nvidia driver. Before doing this please add `nomodeset rd.driver.blacklist=nouveau` to your linux [kernel boot parameters](https://www.kernel.org/doc/html/v4.17/admin-guide/kernel-parameters.html). -2. nvidia-docker2: [Install doc](`https://github.com/NVIDIA/nvidia-docker/wiki/Installation-(version-2.0)`) +Linux: +``` +# download llvm source code 6.0.1 and unzip to /llvm/source/path, then install to /llvm/install/path +cd /llvm/source/path +mkdir build +cd build +cmake .. -DLLVM_TARGETS_TO_BUILD=X86 -DCMAKE_BUILD_TYPE=Release +cmake --build . +cmake -DCMAKE_INSTALL_PREFIX=/llvm/install/path -DBUILD_TYPE=Release -P cmake_install.cmake +``` -To test if your nvidia-docker works: +Then you can build from source by using the following command from the onnxruntime directory: +Windows: ``` -docker run --runtime=nvidia --rm nvidia/cuda nvidia-smi +build.bat --use_tvm --use_llvm --llvm_path=\llvm\install\path\lib\cmake\llvm --use_mklml --use_nuphar --build_shared_lib --build_csharp --enable_pybind --config=Release ``` -Then build a docker image. We provided a sample for use: +Linux: ``` -cd tools/ci_build/github/linux/docker -docker build -t cuda_dev -f Dockerfile.ubuntu_gpu . +./build.sh --use_tvm --use_llvm --llvm_path=/llvm/install/path/lib/cmake/llvm --use_mklml --use_nuphar --build_shared_lib --build_csharp --enable_pybind --config=Release ``` -Then run it +--- + +## Options +### OpenMP ``` -./tools/ci_build/github/linux/run_dockerbuild.sh +./build.sh --use_openmp (for Linux) +./build.bat --use_openmp (for Windows) ``` -## ARM Builds +--- + +### OpenBLAS +**Windows** +Instructions on how to build OpenBLAS for Windows can be found here: https://github.com/xianyi/OpenBLAS/wiki/How-to-use-OpenBLAS-in-Microsoft-Visual-Studio#build-openblas-for-universal-windows-platform. + +Once you have the OpenBLAS binaries, build ONNX Runtime with `./build.bat --use_openblas` + +**Linux** +For Linux (e.g. Ubuntu 16.04), install the libopenblas-dev package +`sudo apt-get install libopenblas-dev` and build with `./build.sh --use_openblas` + +--- + +## Architectures +### x86 + - For Windows, just add the --x86 argument when launching build.bat + - For Linux, it must be built on an x86 OS; the --x86 argument also needs to be specified to build.sh + +--- + +### ARM We have experimental support for Linux ARM builds. Windows on ARM is well tested. -### Cross compiling for ARM with Docker (Linux/Windows - FASTER, RECOMMENDED) +#### Cross compiling for ARM with Docker (Linux/Windows - FASTER, RECOMMENDED) This method allows you to compile using a desktop or cloud VM. This is much faster than compiling natively and avoids out-of-memory issues that may be encountered when on lower-powered ARM devices. The resulting ONNX Runtime Python wheel (.whl) file is then deployed to an ARM device where it can be invoked in Python 3 scripts. The Dockerfile used in these instructions specifically targets Raspberry Pi 3/3+ running Raspbian Stretch. 
The same approach should work for other ARM devices, but may require some changes to the Dockerfile such as choosing a different base image (Line 0: `FROM ...`). @@ -296,7 +393,7 @@ The Dockerfile used in these instructions specifically targets Raspberry Pi 3/3+ ``` 10. Test installation by following the instructions [here](https://microsoft.github.io/onnxruntime/) -### Cross compiling on Linux (without Docker) +#### Cross compiling on Linux (without Docker) 1. Get the corresponding toolchain. For example, if your device is Raspberry Pi and the device os is Ubuntu 16.04, you may use gcc-linaro-6.3.1 from [https://releases.linaro.org/components/toolchain/binaries](https://releases.linaro.org/components/toolchain/binaries) 2. Setup env vars ```bash @@ -321,8 +418,7 @@ The Dockerfile used in these instructions specifically targets Raspberry Pi 3/3+ ``` 6. Append `-DONNX_CUSTOM_PROTOC_EXECUTABLE=/path/to/protoc -DCMAKE_TOOLCHAIN_FILE=path/to/tool.cmake` to your cmake args, run cmake and make to build it. - -### Native compiling on Linux ARM device (SLOWER) +#### Native compiling on Linux ARM device (SLOWER) Docker build runs on a Raspberry Pi 3B with Raspbian Stretch Lite OS (Desktop version will run out memory when linking the .so file) will take 8-9 hours in total. ```bash sudo apt-get update @@ -374,26 +470,10 @@ ls -l /code/onnxruntime/build/Linux/MinSizeRel/*.so ls -l /code/onnxruntime/build/Linux/MinSizeRel/dist/*.whl ``` -### Cross compiling on Windows -#### Using Visual C++ compilers +#### Cross compiling on Windows +**Using Visual C++ compilers** 1. Download and install Visual C++ compilers and libraries for ARM(64). If you have Visual Studio installed, please use the Visual Studio Installer (look under the section `Individual components` after choosing to `modify` Visual Studio) to download and install the corresponding ARM(64) compilers and libraries. 2. Use `build.bat` and specify `--arm` or `--arm64` as the build option to start building. Preferably use `Developer Command Prompt for VS` or make sure all the installed cross-compilers are findable from the command prompt being used to build using the PATH environmant variable. -### Using other compilers -(TODO) - -## Android Builds - -### Cross compiling on Linux - -1. Get Android NDK from https://developer.android.com/ndk/downloads. Please unzip it after downloading. - -2. Get a pre-compiled protoc: - - You may get it from https://github.com/protocolbuffers/protobuf/releases/download/v3.6.1/protoc-3.6.1-linux-x86_64.zip. Please unzip it after downloading. - -3. Denote the unzip destination in step 1 as $ANDROID_NDK, append `-DCMAKE_TOOLCHAIN_FILE=$ANDROID_NDK/build/cmake/android.toolchain.cmake -DANDROID_ABI=arm64-v8a -DONNX_CUSTOM_PROTOC_EXECUTABLE=path/to/protoc` to your cmake args, run cmake and make to build it. - -Note: For 32-bit devices, replace `-DANDROID_ABI=arm64-v8a` to `-DANDROID_ABI=armeabi-v7a`. diff --git a/README.md b/README.md index b1a715dbfe95d..5eed6744af919 100644 --- a/README.md +++ b/README.md @@ -11,15 +11,19 @@ [ONNX](https://onnx.ai) is an interoperable format for machine learning models supported by various ML and DNN frameworks and tools. The universal format makes it easier to interoperate between frameworks and maximize the reach of hardware optimization investments. 
*** +**[Key Features](#key-features)** + **Setup** * [Installation](#installation) * [APIs and Official Binaries](#apis-and-official-builds) * [Building from Source](#building-from-source) -**Getting Started** +**Usage** * [Getting ONNX Models](#getting-onnx-models) * [Deploying ONNX Runtime](#deploying-onnx-runtime) -* [Examples and Tutorials](#examples-and-tutorials) +* [Performance Tuning](#performance-tuning) + +**[Examples and Tutorials](#examples-and-tutorials)** **More Info** * [Technical Design Details](#technical-design-details) @@ -29,39 +33,43 @@ **[License](#license)** *** -## Key Features -### Run any ONNX model +# Key Features +## Run any ONNX model ONNX Runtime provides comprehensive support of the ONNX spec and can be used to run all models based on ONNX v1.2.1 and higher. See version compatibility details [here](https://github.com/microsoft/onnxruntime/blob/master/docs/Versioning.md). -*Note: Some operators not supported in the current ONNX version may be available as a [Contrib Operator](https://github.com/microsoft/onnxruntime/blob/master/docs/ContribOperators.md)* - **Traditional ML support** In addition to DNN models, ONNX Runtime fully supports the [ONNX-ML profile](https://github.com/onnx/onnx/blob/master/docs/Operators-ml.md) of the ONNX spec for traditional ML scenarios. -### High Performance +For the full set of operators and types supported, please see [operator documentation](https://github.com/microsoft/onnxruntime/blob/master/docs/OperatorKernels.md) + +*Note: Some operators not supported in the current ONNX version may be available as a [Contrib Operator](https://github.com/microsoft/onnxruntime/blob/master/docs/ContribOperators.md)* + + +## High Performance ONNX Runtime supports both CPU and GPU. Using various graph optimizations and accelerators, ONNX Runtime can provide lower latency compared to other runtimes for faster end-to-end customer experiences and minimized machine utilization costs. Currently ONNX Runtime supports the following accelerators: -* CPU - * MLAS (Microsoft Linear Algebra Subprograms) - * MKL-DNN - * MKL-ML - * [Intel nGraph](https://github.com/microsoft/onnxruntime/blob/master/docs/execution_providers/nGraph-ExecutionProvider.md) -* GPU - * CUDA - * [TensorRT](https://github.com/microsoft/onnxruntime/blob/master/docs/execution_providers/TensorRT-ExecutionProvider.md) +* MLAS (Microsoft Linear Algebra Subprograms) +* [MKL-DNN](https://github.com/microsoft/onnxruntime/blob/master/docs/execution_providers/MKL-DNN-ExecutionProvider.md) - [subgraph optimization](https://github.com/microsoft/onnxruntime/blob/master/docs/execution_providers/MKL-DNN-Subgraphs.md) +* MKL-ML +* [Intel nGraph](https://github.com/microsoft/onnxruntime/blob/master/docs/execution_providers/nGraph-ExecutionProvider.md) +* CUDA +* [TensorRT](https://github.com/microsoft/onnxruntime/blob/master/docs/execution_providers/TensorRT-ExecutionProvider.md) +* [OpenVINO](https://github.com/microsoft/onnxruntime/blob/master/docs/execution_providers/OpenVINO-ExecutionProvider.md) +* [Nuphar](docs/execution_providers/Nuphar-ExecutionProvider.md) -Not all variations are supported in the [official release builds](#apis-and-official-builds), but can be built from source following [these instructions](https://github.com/Microsoft/onnxruntime/blob/master/BUILD.md). 
+Not all variations are supported in the [official release builds](#apis-and-official-builds), but can be built from source following [these instructions](https://github.com/Microsoft/onnxruntime/blob/master/BUILD.md). Find Dockerfiles [here](https://github.com/microsoft/onnxruntime/tree/master/dockerfiles). We are continuously working to integrate new execution providers for further improvements in latency and efficiency. If you are interested in contributing a new execution provider, please see [this page](docs/AddingExecutionProvider.md). -### Cross Platform + +## Cross Platform [API documentation and package installation](https://github.com/microsoft/onnxruntime#installation) ONNX Runtime is available for Linux, Windows, Mac with Python, C#, and C APIs, with more to come! If you have specific scenarios that are not currently supported, please share your suggestions and scenario details via [Github Issues](https://github.com/microsoft/onnxruntime/issues). - +*** # Installation **Quick Start:** The [ONNX-Ecosystem Docker container image](https://github.com/onnx/onnx-docker/tree/master/onnx-ecosystem) is available on Dockerhub and includes ONNX Runtime (CPU, Python), dependencies, tools to convert from various frameworks, and Jupyter notebooks to help get started. @@ -80,7 +88,7 @@ Additional dockerfiles for some features can be found [here](https://github.com/ |---|:---|:---|:---| | **Python** | **[pypi: onnxruntime](https://pypi.org/project/onnxruntime)**

Windows (x64)
Linux (x64)
Mac OS X (x64) | -- | **[pypi: onnxruntime-gpu](https://pypi.org/project/onnxruntime-gpu)**

Windows (x64)
Linux (x64) | | **C#** | **[Nuget: Microsoft.ML.OnnxRuntime](https://www.nuget.org/packages/Microsoft.ML.OnnxRuntime/)**

Windows (x64, x86)
Linux (x64, x86)
Mac OS X (x64) | **[Nuget: Microsoft.ML.OnnxRuntime.MKLML](https://www.nuget.org/packages/Microsoft.ML.OnnxRuntime.MKLML/)**

Windows (x64)
Linux (x64)
Mac OS X (x64) | **[Nuget: Microsoft.ML.OnnxRuntime.Gpu](https://www.nuget.org/packages/Microsoft.ML.OnnxRuntime.Gpu/)**

Windows (x64)
Linux (x64) | -| **C** | **[Nuget: Microsoft.ML.OnnxRuntime](https://www.nuget.org/packages/Microsoft.ML.OnnxRuntime)**

**[.zip, .tgz](https://aka.ms/onnxruntime-release)**

Windows (x64, x86)
Linux (x64, x86)
Mac OS X (x64 | **[Nuget: Microsoft.ML.OnnxRuntime.MKLML](https://www.nuget.org/packages/Microsoft.ML.OnnxRuntime.MKLML/)**

Windows (x64)
Linux (x64)
Mac OS X (x64) | **[Nuget: Microsoft.ML.OnnxRuntime.Gpu](https://www.nuget.org/packages/Microsoft.ML.OnnxRuntime.Gpu/)**

**[.zip, .tgz](https://aka.ms/onnxruntime-release)**

Windows (x64)
Linux (x64) | +| **C/C++ wrapper** | **[Nuget: Microsoft.ML.OnnxRuntime](https://www.nuget.org/packages/Microsoft.ML.OnnxRuntime)**

**[.zip, .tgz](https://aka.ms/onnxruntime-release)**

Windows (x64, x86)
Linux (x64, x86)
Mac OS X (x64) | **[Nuget: Microsoft.ML.OnnxRuntime.MKLML](https://www.nuget.org/packages/Microsoft.ML.OnnxRuntime.MKLML/)**

Windows (x64)
Linux (x64)
Mac OS X (x64) | **[Nuget: Microsoft.ML.OnnxRuntime.Gpu](https://www.nuget.org/packages/Microsoft.ML.OnnxRuntime.Gpu/)**

**[.zip, .tgz](https://aka.ms/onnxruntime-release)**

Windows (x64)
Linux (x64) | #### System Requirements (pre-requisite dependencies) * ONNX Runtime binaries in the CPU packages use OpenMP and depend on the library being available at runtime in the @@ -88,20 +96,26 @@ system. * For Windows, **OpenMP** support comes as part of VC runtime. It is also available as redist packages: [vc_redist.x64.exe](https://aka.ms/vs/15/release/vc_redist.x64.exe) and [vc_redist.x86.exe](https://aka.ms/vs/15/release/vc_redist.x86.exe) * For Linux, the system must have **libgomp.so.1** which can be installed using `apt-get install libgomp1`. -* GPU builds require the **CUDA 10.0 and cuDNN 7.3** runtime libraries being installed on the system. Older releases used 9.1/7.1 - please refer to [release notes](https://github.com/microsoft/onnxruntime/releases) for more details. -* Python binaries are compatible with **Python 3.5-3.7**. See [Python Dev Notes](https://github.com/microsoft/onnxruntime/blob/master/docs/Python_Dev_Notes.md) +* GPU builds require CUDA runtime libraries being installed on the system: + * Version: **CUDA 10.0** and **cuDNN 7.3** + * Linux Python packages require **CUDA 10.1** and **cuDNN 7.6** + * Older ONNX Runtime releases: used **CUDA 9.1** and **cuDNN 7.1** - please refer to [prior release notes](https://github.com/microsoft/onnxruntime/releases) for more details. +* Python binaries are compatible with **Python 3.5-3.7**. See [Python Dev Notes](https://github.com/microsoft/onnxruntime/blob/master/docs/Python_Dev_Notes.md). If using `pip` to be download the Python binaries, run `pip install --upgrade pip` prior to downloading. * Certain operators makes use of system locales. Installation of the **English language package** and configuring `en_US.UTF-8 locale` is required. * For Ubuntu install [language-pack-en package](https://packages.ubuntu.com/search?keywords=language-pack-en) * Run the following commands: `locale-gen en_US.UTF-8` `update-locale LANG=en_US.UTF-8` * Follow similar procedure to configure other locales on other platforms. - + ## Building from Source If additional build flavors are needed, please find instructions on building from source at [Build ONNX Runtime](BUILD.md). For production scenarios, it's strongly recommended to build from an [official release branch](https://github.com/microsoft/onnxruntime/releases). Dockerfiles are available [here](https://github.com/microsoft/onnxruntime/tree/faxu-doc-updates/tools/ci_build/github/linux/docker) to help you get started. +*** +# Usage + ## Getting ONNX Models * The [ONNX Model Zoo](https://github.com/onnx/models) has popular ready-to-use pre-trained models. * To export or convert a trained ONNX model trained from various frameworks, see [ONNX Tutorials](https://github.com/onnx/tutorials). Versioning comptability information can be found under [Versioning](docs/Versioning.md#tool-compatibility) @@ -115,8 +129,12 @@ ONNX Runtime can be deployed to the cloud for model inferencing using [Azure Mac **ONNX Runtime Server (beta)** is a hosted application for serving ONNX models using ONNX Runtime, providing a REST API for prediction. Usage details can be found [here](https://github.com/microsoft/onnxruntime/blob/master/docs/ONNX_Runtime_Server_Usage.md), and image installation instructions are [here](https://github.com/microsoft/onnxruntime/tree/master/dockerfiles#onnx-runtime-server-preview). 
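Pulling the Linux items from the System Requirements above into one place, a typical Ubuntu setup looks roughly like this; the package names and locale commands are the ones listed above, and `sudo` access is assumed:
```bash
# OpenMP runtime required by the CPU packages
sudo apt-get install -y libgomp1
# English language pack and en_US.UTF-8 locale, used by certain operators
sudo apt-get install -y language-pack-en
sudo locale-gen en_US.UTF-8
sudo update-locale LANG=en_US.UTF-8
# keep pip current before installing the Python binaries
pip install --upgrade pip
```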
-## Examples and Tutorials -### Python +## Performance Tuning +ONNX Runtime is open and extensible, supporting a broad set of configurations and execution providers for model acceleration. For performance tuning guidance, please see [this page](https://github.com/microsoft/onnxruntime/blob/master/docs/ONNX_Runtime_Perf_Tuning.md). + +*** +# Examples and Tutorials +## Python * [Basic Inferencing Sample](https://github.com/onnx/onnx-docker/blob/master/onnx-ecosystem/inference_demos/simple_onnxruntime_inference.ipynb) * [Inferencing (Resnet50)](https://github.com/onnx/onnx-docker/blob/master/onnx-ecosystem/inference_demos/resnet50_modelzoo_onnxruntime_inference.ipynb) * [Inferencing samples](https://github.com/onnx/onnx-docker/tree/master/onnx-ecosystem/inference_demos) using [ONNX-Ecosystem Docker image](https://github.com/onnx/onnx-docker/tree/master/onnx-ecosystem) @@ -127,21 +145,29 @@ ONNX Runtime can be deployed to the cloud for model inferencing using [Azure Mac **Deployment with AzureML** -* Inferencing: [Inferencing Facial Expression Recognition](https://github.com/Azure/MachineLearningNotebooks/blob/master/how-to-use-azureml/deployment/onnx/onnx-inference-facial-expression-recognition-deploy.ipynb), [Inferencing MNIST Handwritten Digits](https://github.com/Azure/MachineLearningNotebooks/blob/master/how-to-use-azureml/deployment/onnx/onnx-inference-mnist-deploy.ipynb), [ Resnet50 Image Classification](https://github.com/Azure/MachineLearningNotebooks/blob/master/how-to-use-azureml/deployment/onnx/onnx-modelzoo-aml-deploy-resnet50.ipynb), [TinyYolo](https://github.com/Azure/MachineLearningNotebooks/blob/master/how-to-use-azureml/deployment/onnx/onnx-convert-aml-deploy-tinyyolo.ipynb) -* [Train and Inference MNIST from Pytorch](https://github.com/Azure/MachineLearningNotebooks/blob/master/how-to-use-azureml/deployment/onnx/onnx-train-pytorch-aml-deploy-mnist.ipynb) -* [FER+ on Azure Kubernetes Service with TensorRT](https://github.com/microsoft/onnxruntime/blob/master/docs/python/notebooks/onnx-inference-byoc-gpu-cpu-aks.ipynb) - - -### C# +* Inferencing using [ONNX Model Zoo](https://github.com/onnx/models) models: + * [Facial Expression Recognition](https://github.com/Azure/MachineLearningNotebooks/blob/master/how-to-use-azureml/deployment/onnx/onnx-inference-facial-expression-recognition-deploy.ipynb) + * [MNIST Handwritten Digits](https://github.com/Azure/MachineLearningNotebooks/blob/master/how-to-use-azureml/deployment/onnx/onnx-inference-mnist-deploy.ipynb) + * [Resnet50 Image Classification](https://github.com/Azure/MachineLearningNotebooks/blob/master/how-to-use-azureml/deployment/onnx/onnx-modelzoo-aml-deploy-resnet50.ipynb) +* Convert existing model for Inferencing: + * [TinyYolo](https://github.com/Azure/MachineLearningNotebooks/blob/master/how-to-use-azureml/deployment/onnx/onnx-convert-aml-deploy-tinyyolo.ipynb) +* Train a model with PyTorch and Inferencing: + * [MNIST](https://github.com/Azure/MachineLearningNotebooks/blob/master/how-to-use-azureml/deployment/onnx/onnx-train-pytorch-aml-deploy-mnist.ipynb) + +* Inferencing with TensorRT Execution Provider on GPU (AKS) + * [FER+](https://github.com/microsoft/onnxruntime/blob/master/docs/python/notebooks/onnx-inference-byoc-gpu-cpu-aks.ipynb) + + +## C# * [Inferencing Tutorial](https://github.com/microsoft/onnxruntime/blob/master/docs/CSharp_API.md#getting-started) -### C/C++ +## C/C++ * [Basic Inferencing (SqueezeNet) - 
C](https://github.com/microsoft/onnxruntime/blob/master/csharp/test/Microsoft.ML.OnnxRuntime.EndToEndTests.Capi/C_Api_Sample.cpp) * [Basic Inferencing (SqueezeNet) - C++](https://github.com/microsoft/onnxruntime/blob/master/csharp/test/Microsoft.ML.OnnxRuntime.EndToEndTests.Capi/CXX_Api_Sample.cpp) * [Inferencing (MNIST) - C++](https://github.com/microsoft/onnxruntime/tree/master/samples/c_cxx/MNIST) - +*** # Technical Design Details * [High level architectural design](docs/HighLevelDesign.md) * [Versioning](docs/Versioning.md) @@ -153,6 +179,7 @@ ONNX Runtime can be deployed to the cloud for model inferencing using [Azure Mac transform](include/onnxruntime/core/optimizer/graph_transformer.h) * [Add a new rewrite rule](include/onnxruntime/core/optimizer/rewrite_rule.h) +*** # Contribute We welcome contributions! Please see the [contribution guidelines](CONTRIBUTING.md). @@ -163,6 +190,6 @@ For any feedback or to report a bug, please file a [GitHub Issue](https://github This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/). For more information see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) or contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with any additional questions or comments. - +*** # License [MIT License](LICENSE) diff --git a/cgmanifest.json b/cgmanifest.json index 2fd8a43254e26..410f8210e1c27 100644 --- a/cgmanifest.json +++ b/cgmanifest.json @@ -49,7 +49,7 @@ "component":{ "type":"git", "git": { - "commitHash": "65b8e0f9979fbade16e3becbdfa69c0764946f72", + "commitHash": "7d90796473295ca3cdf976ed772215c5980ad3e0", "repositoryUrl": "https://github.com/onnx/onnx.git" } } diff --git a/cmake/CMakeLists.txt b/cmake/CMakeLists.txt index 9ee7470b0b2b3..04ac0cb65aa5c 100644 --- a/cmake/CMakeLists.txt +++ b/cmake/CMakeLists.txt @@ -50,9 +50,10 @@ option(onnxruntime_USE_OPENVINO "Build with OpenVINO support" OFF) option(onnxruntime_USE_NSYNC "Build with NSYNC support. This option only takes effect on Linux" OFF) option(onnxruntime_USE_EIGEN_FOR_BLAS "Use eign for blas" ON) option(onnxruntime_USE_NNAPI "Build with DNNLibrary for Android NNAPI support" OFF) -option(onnxruntime_USE_MLAS "Use optimized blas library for GEMM and 2D Convolution" ON) option(onnxruntime_USE_MKLDNN "Build with MKL-DNN support" OFF) option(onnxruntime_USE_MKLML "Build MKL-DNN with MKL-ML binary dependency" OFF) +option(onnxruntime_USE_GEMMLOWP "Build with gemmlowp for quantized gemm" OFF) +option(onnxruntime_USE_AUTOML "Build AutoML support" ON) option(onnxruntime_USE_NGRAPH "Build with nGraph support" OFF) option(onnxruntime_USE_OPENBLAS "Use openblas" OFF) option(onnxruntime_DEV_MODE "Enable developer warnings and treat most of them as error." 
OFF) @@ -349,8 +350,13 @@ if (onnxruntime_USE_TVM) add_definitions(-DUSE_TVM) set(onnxruntime_tvm_libs onnxruntime_codegen_tvm) - list(APPEND onnxruntime_EXTERNAL_LIBRARIES tvm nnvm_compiler) - + # needs to link with stdc++fs in Linux + if(UNIX) + if (NOT APPLE) + set(FS_STDLIB stdc++fs) + endif() + endif() + list(APPEND onnxruntime_EXTERNAL_LIBRARIES tvm nnvm_compiler ${FS_STDLIB}) list(APPEND onnxruntime_EXTERNAL_DEPENDENCIES tvm nnvm_compiler) endif() @@ -367,10 +373,6 @@ if (onnxruntime_RUN_ONNX_TESTS) add_definitions(-DORT_RUN_EXTERNAL_ONNX_TESTS) endif() -if (onnxruntime_USE_MLAS) - add_definitions(-DUSE_MLAS) -endif() - #Adjust warning flags if (WIN32) add_definitions(-DPLATFORM_WINDOWS -DNOGDI -DNOMINMAX -D_USE_MATH_DEFINES) @@ -476,6 +478,10 @@ if (onnxruntime_USE_MKLDNN OR onnxruntime_USE_MKLML) include(mkldnn) endif() +if(onnxruntime_USE_GEMMLOWP) + add_definitions(-DUSE_GEMMLOWP=1) +endif() + if (onnxruntime_USE_MKLML) add_definitions(-DUSE_MKLML=1 -DUSE_MKLML_FOR_BLAS=1) if (WIN32 OR APPLE) @@ -646,6 +652,12 @@ include(onnxruntime_optimizer.cmake) include(onnxruntime_session.cmake) include(onnxruntime_mlas.cmake) +if(onnxruntime_USE_AUTOML) + add_definitions(-DMICROSOFT_AUTOML) + # Build shared featurizer library + include(onnxruntime_automl_featurizers.cmake) +endif() + if(WIN32) list(APPEND onnxruntime_EXTERNAL_LIBRARIES Shlwapi) list(APPEND onnxruntime_EXTERNAL_LIBRARIES debug Dbghelp) diff --git a/cmake/external/mkldnn.cmake b/cmake/external/mkldnn.cmake index 364ba88a891c8..e3a638cc7a183 100644 --- a/cmake/external/mkldnn.cmake +++ b/cmake/external/mkldnn.cmake @@ -11,6 +11,8 @@ if(WIN32) set(MKLDNN_SHARED_LIB mkldnn.dll) set(MKLDNN_IMPORT_LIB mkldnn.lib) if(onnxruntime_USE_MKLML) + # Windows-only updated MKLML binary which contains fix for thread cleanup hang. + set(MKLML_VERSION 2020.0.20190813) set(MKLML_SHARED_LIB mklml.dll) set(MKLML_IMPORT_LIB mklml.lib) set(IOMP5MD_SHARED_LIB libiomp5md.dll) @@ -59,15 +61,15 @@ if (onnxruntime_USE_MKLDNN) set(MKLDNN_DLL_PATH ${MKLDNN_LIB_DIR}/${MKLDNN_SHARED_LIB}) endif() set(MKLDNN_INCLUDE_DIR ${MKLDNN_INSTALL}/include) - set (MKLDNN_CMAKE_EXTRA_ARGS) + set(MKLDNN_CMAKE_EXTRA_ARGS) + set(MKLDNN_PATCH_COMMAND1 git apply ${CMAKE_SOURCE_DIR}/patches/mkldnn/mem-patch.cmake.patch) + # discard prior changes due to patching in mkldnn source to unblock incremental builds. + set(MKLDNN_PATCH_DISCARD_COMMAND cd ${MKLDNN_SOURCE} && git checkout -- .) if(NOT onnxruntime_BUILD_FOR_NATIVE_MACHINE) # pre-v1.0 list(APPEND MKLDNN_CMAKE_EXTRA_ARGS "-DARCH_OPT_FLAGS=") # v1.0 list(APPEND MKLDNN_CMAKE_EXTRA_ARGS "-DMKLDNN_ARCH_OPT_FLAGS=") - set(MKLDNN_PATCH_COMMAND1 git apply ${CMAKE_SOURCE_DIR}/patches/mkldnn/mem-patch.cmake.patch) - # discard prior changes due to patching in mkldnn source to unblock incremental builds. - set(MKLDNN_PATCH_DISCARD_COMMAND cd ${MKLDNN_SOURCE} && git checkout -- .) endif() ExternalProject_Add(project_mkldnn PREFIX mkl-dnn diff --git a/cmake/external/ngraph.cmake b/cmake/external/ngraph.cmake index 12d0b6e1431db..45aae7d44f512 100644 --- a/cmake/external/ngraph.cmake +++ b/cmake/external/ngraph.cmake @@ -11,7 +11,7 @@ set(ngraph_SRC ${CMAKE_CURRENT_BINARY_DIR}/ngraph/src/project_ngraph) set(prebuilt_ONNX_SOURCE_DIR "${PROJECT_SOURCE_DIR}/external/onnx") set(prebuilt_ONNX_BINARY_DIR "${CMAKE_CURRENT_BINARY_DIR}/onnx") set(ngraph_URL "https://github.com/NervanaSystems/ngraph.git") -set(ngraph_TAG "v0.18.1") +set(ngraph_TAG "v0.22.1") # Libraries for python package. 
if (WIN32) @@ -42,7 +42,7 @@ else() endif() # discard prior changes due to unblock incremental builds. -set(NGRAPH_PATCH_DISCARD_COMMAND cd ${ngraph_SRC} && git checkout -- .) +set(NGRAPH_PATCH_DISCARD_COMMAND cd ${ngraph_SRC} && git reset HEAD --hard && git clean -fx) if (MSVC) set(prebuilt_ONNX_BINARY_DIR "${CMAKE_CURRENT_BINARY_DIR}/onnx/${CMAKE_BUILD_TYPE}") @@ -54,12 +54,12 @@ if (MSVC) PREFIX ngraph GIT_REPOSITORY ${ngraph_URL} GIT_TAG ${ngraph_TAG} + GIT_CONFIG core.autocrlf=input PATCH_COMMAND ${NGRAPH_PATCH_DISCARD_COMMAND} COMMAND ${CMAKE_COMMAND} -E copy ${PROJECT_SOURCE_DIR}/patches/ngraph/ngraph_onnx.cmake ${ngraph_SRC}/cmake/external_onnx.cmake COMMAND git apply --ignore-space-change --ignore-whitespace ${PROJECT_SOURCE_DIR}/patches/ngraph/ngraph_protobuf.patch - COMMAND git apply --ignore-space-change --ignore-whitespace ${PROJECT_SOURCE_DIR}/patches/ngraph/ngraph_fix_install_error.patch - COMMAND git apply --ignore-space-change --ignore-whitespace ${PROJECT_SOURCE_DIR}/patches/ngraph/ngraph_fix_library_path.patch COMMAND git apply --ignore-space-change --ignore-whitespace ${PROJECT_SOURCE_DIR}/patches/ngraph/ngraph_fix_memory.patch + COMMAND git apply --ignore-space-change --ignore-whitespace ${PROJECT_SOURCE_DIR}/patches/ngraph/ngraph_fix_mkldnn_missing_symbol.patch CMAKE_ARGS -DCMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE} -DNGRAPH_DEX_ONLY=ON diff --git a/cmake/external/onnx b/cmake/external/onnx index 65b8e0f9979fb..7d90796473295 160000 --- a/cmake/external/onnx +++ b/cmake/external/onnx @@ -1 +1 @@ -Subproject commit 65b8e0f9979fbade16e3becbdfa69c0764946f72 +Subproject commit 7d90796473295ca3cdf976ed772215c5980ad3e0 diff --git a/cmake/external/onnx-tensorrt b/cmake/external/onnx-tensorrt index 3aa0a1cb41fae..6c37109733a9b 160000 --- a/cmake/external/onnx-tensorrt +++ b/cmake/external/onnx-tensorrt @@ -1 +1 @@ -Subproject commit 3aa0a1cb41fae88b7787b6289a729ed9046a18e4 +Subproject commit 6c37109733a9bbf8211f0ca78a85804cb376eca0 diff --git a/cmake/external/tvm b/cmake/external/tvm index fd4801612817f..b4bff71f36eca 160000 --- a/cmake/external/tvm +++ b/cmake/external/tvm @@ -1 +1 @@ -Subproject commit fd4801612817f96e890058656834deb925fc064a +Subproject commit b4bff71f36eca1e840dd280ba485cad186718844 diff --git a/cmake/onnxruntime.cmake b/cmake/onnxruntime.cmake index 91508a8aa8f57..8a6bf402e1488 100644 --- a/cmake/onnxruntime.cmake +++ b/cmake/onnxruntime.cmake @@ -19,13 +19,13 @@ foreach(f ${ONNXRUNTIME_PROVIDER_NAMES}) list(APPEND SYMBOL_FILES "${ONNXRUNTIME_ROOT}/core/providers/${f}/symbols.txt") endforeach() -add_custom_command(OUTPUT ${SYMBOL_FILE} - COMMAND ${PYTHON_EXECUTABLE} "${REPO_ROOT}/tools/ci_build/gen_def.py" --version_file "${ONNXRUNTIME_ROOT}/../VERSION_NUMBER" --src_root "${ONNXRUNTIME_ROOT}" --config ${ONNXRUNTIME_PROVIDER_NAMES} --style=${OUTPUT_STYLE} --output ${SYMBOL_FILE} +add_custom_command(OUTPUT ${SYMBOL_FILE} ${CMAKE_CURRENT_BINARY_DIR}/generated_source.c + COMMAND ${PYTHON_EXECUTABLE} "${REPO_ROOT}/tools/ci_build/gen_def.py" --version_file "${ONNXRUNTIME_ROOT}/../VERSION_NUMBER" --src_root "${ONNXRUNTIME_ROOT}" --config ${ONNXRUNTIME_PROVIDER_NAMES} --style=${OUTPUT_STYLE} --output ${SYMBOL_FILE} --output_source ${CMAKE_CURRENT_BINARY_DIR}/generated_source.c DEPENDS ${SYMBOL_FILES} WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}) -add_custom_target(onnxruntime_generate_def ALL DEPENDS ${SYMBOL_FILE}) -add_library(onnxruntime SHARED ${onnxruntime_session_srcs}) +add_custom_target(onnxruntime_generate_def ALL DEPENDS ${SYMBOL_FILE} 
${CMAKE_CURRENT_BINARY_DIR}/generated_source.c) +add_library(onnxruntime SHARED ${CMAKE_CURRENT_BINARY_DIR}/generated_source.c) set_target_properties(onnxruntime PROPERTIES VERSION ${ORT_VERSION}) add_dependencies(onnxruntime onnxruntime_generate_def ${onnxruntime_EXTERNAL_DEPENDENCIES}) target_include_directories(onnxruntime PRIVATE ${ONNXRUNTIME_ROOT}) @@ -37,12 +37,8 @@ endif() if(UNIX) if (APPLE) - set(BEGIN_WHOLE_ARCHIVE -Xlinker -all_load) - set(END_WHOLE_ARCHIVE -Xlinker -noall_load) set(ONNXRUNTIME_SO_LINK_FLAG "-Xlinker -dead_strip") else() - set(BEGIN_WHOLE_ARCHIVE -Xlinker --whole-archive) - set(END_WHOLE_ARCHIVE -Xlinker --no-whole-archive) set(ONNXRUNTIME_SO_LINK_FLAG "-Xlinker --version-script=${SYMBOL_FILE} -Xlinker --no-undefined -Xlinker --gc-sections") endif() else() @@ -59,7 +55,7 @@ endif() #The BEGIN_WHOLE_ARCHIVE/END_WHOLE_ARCHIVE part should contain the implementations of all the C API functions target_link_libraries(onnxruntime PRIVATE - ${BEGIN_WHOLE_ARCHIVE} + onnxruntime_session ${onnxruntime_libs} ${PROVIDERS_CUDA} ${PROVIDERS_MKLDNN} @@ -67,12 +63,12 @@ target_link_libraries(onnxruntime PRIVATE ${PROVIDERS_NNAPI} ${PROVIDERS_TENSORRT} ${PROVIDERS_OPENVINO} + ${PROVIDERS_NUPHAR} onnxruntime_optimizer onnxruntime_providers onnxruntime_util ${onnxruntime_tvm_libs} onnxruntime_framework - ${END_WHOLE_ARCHIVE} onnxruntime_graph onnxruntime_common onnxruntime_mlas diff --git a/cmake/onnxruntime_automl_featurizers.cmake b/cmake/onnxruntime_automl_featurizers.cmake new file mode 100644 index 0000000000000..daffe92842826 --- /dev/null +++ b/cmake/onnxruntime_automl_featurizers.cmake @@ -0,0 +1,44 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# This source code should not depend on the onnxruntime and may be built independently + +file(GLOB automl_featurizers_srcs CONFIGURE_DEPENDS + "${ONNXRUNTIME_ROOT}/core/automl/featurizers/src/FeaturizerPrep/*.h" + "${ONNXRUNTIME_ROOT}/core/automl/featurizers/src/FeaturizerPrep/Featurizers/*.h" + "${ONNXRUNTIME_ROOT}/core/automl/featurizers/src/FeaturizerPrep/Featurizers/*.cpp" +) + +source_group(TREE ${ONNXRUNTIME_ROOT}/core/automl/ FILES ${onnxruntime_automl_featurizers_srcs}) + +add_library(automl_featurizers ${automl_featurizers_srcs}) + +target_include_directories(automl_featurizers PRIVATE ${ONNXRUNTIME_ROOT} PUBLIC ${CMAKE_CURRENT_BINARY_DIR}) + +set_target_properties(automl_featurizers PROPERTIES FOLDER "AutoMLFeaturizers") + +# Individual featurizers unit tests added at bulk +file(GLOB automl_featurizers_tests_srcs + "${ONNXRUNTIME_ROOT}/core/automl/featurizers/src/FeaturizerPrep/Featurizers/UnitTests/*.cpp" +) + +list(APPEND automl_featurizers_tests_srcs + "${ONNXRUNTIME_ROOT}/core/automl/featurizers/src/FeaturizerPrep/UnitTests/Traits_UnitTests.cpp" + "${ONNXRUNTIME_ROOT}/core/automl/featurizers/src/FeaturizerPrep/UnitTests/Featurizer_UnitTest.cpp" + "${ONNXRUNTIME_ROOT}/core/automl/featurizers/src/FeaturizerPrep/UnitTests/test_main.cpp" +) + +add_executable(automl_featurizers_unittests ${automl_featurizers_tests_srcs}) +add_dependencies(automl_featurizers_unittests automl_featurizers) +target_link_libraries(automl_featurizers_unittests PRIVATE gtest automl_featurizers) +source_group(TREE ${ONNXRUNTIME_ROOT}/core/automl/ FILES ${automl_featurizers_tests_srcs}) +set_target_properties(automl_featurizers_unittests PROPERTIES FOLDER "AutoMLFeaturizers") +add_test(NAME automl_featurizers_unittests + COMMAND automl_featurizers_unittests + WORKING_DIRECTORY $ +) + + +if 
(WIN32) + # Add Code Analysis properties to enable C++ Core checks. Have to do it via a props file include. + set_target_properties(automl_featurizers PROPERTIES VS_USER_PROPS ${PROJECT_SOURCE_DIR}/ConfigureVisualStudioCodeAnalysis.props) +endif() diff --git a/cmake/onnxruntime_common.cmake b/cmake/onnxruntime_common.cmake index 0799ab9a6c79e..133ea4b60bf16 100644 --- a/cmake/onnxruntime_common.cmake +++ b/cmake/onnxruntime_common.cmake @@ -53,11 +53,12 @@ target_include_directories(onnxruntime_common PRIVATE ${CMAKE_CURRENT_BINARY_DIR if(onnxruntime_USE_NSYNC) target_compile_definitions(onnxruntime_common PUBLIC USE_NSYNC) endif() -if(onnxruntime_USE_EIGEN_THREADPOOL) - target_include_directories(onnxruntime_common PRIVATE ${eigen_INCLUDE_DIRS}) - target_compile_definitions(onnxruntime_common PUBLIC USE_EIGEN_THREADPOOL) - add_dependencies(onnxruntime_common ${onnxruntime_EXTERNAL_DEPENDENCIES}) + +target_include_directories(onnxruntime_common PUBLIC ${eigen_INCLUDE_DIRS}) +if(NOT onnxruntime_USE_OPENMP) + target_compile_definitions(onnxruntime_common PUBLIC EIGEN_USE_THREADS) endif() +add_dependencies(onnxruntime_common ${onnxruntime_EXTERNAL_DEPENDENCIES}) install(DIRECTORY ${PROJECT_SOURCE_DIR}/../include/onnxruntime/core/common DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/onnxruntime/core) set_target_properties(onnxruntime_common PROPERTIES LINKER_LANGUAGE CXX) diff --git a/cmake/onnxruntime_graph.cmake b/cmake/onnxruntime_graph.cmake index 366eadf680fff..4c05a3307bff0 100644 --- a/cmake/onnxruntime_graph.cmake +++ b/cmake/onnxruntime_graph.cmake @@ -14,6 +14,13 @@ if (onnxruntime_DISABLE_CONTRIB_OPS) ) endif() +if(NOT onnxruntime_USE_AUTOML) + list(REMOVE_ITEM onnxruntime_graph_src + "${ONNXRUNTIME_ROOT}/core/graph/automl_ops/*.h" + "${ONNXRUNTIME_ROOT}/core/graph/automl_ops/*.cc" + ) +endif() + file(GLOB_RECURSE onnxruntime_ir_defs_src CONFIGURE_DEPENDS "${ONNXRUNTIME_ROOT}/core/defs/*.cc" ) @@ -21,6 +28,7 @@ file(GLOB_RECURSE onnxruntime_ir_defs_src CONFIGURE_DEPENDS add_library(onnxruntime_graph ${onnxruntime_graph_src} ${onnxruntime_ir_defs_src}) add_dependencies(onnxruntime_graph onnx_proto gsl) onnxruntime_add_include_to_target(onnxruntime_graph onnxruntime_common gsl onnx onnx_proto protobuf::libprotobuf) + target_include_directories(onnxruntime_graph PRIVATE ${ONNXRUNTIME_ROOT}) set_target_properties(onnxruntime_graph PROPERTIES FOLDER "ONNXRuntime") set_target_properties(onnxruntime_graph PROPERTIES LINKER_LANGUAGE CXX) diff --git a/cmake/onnxruntime_mlas.cmake b/cmake/onnxruntime_mlas.cmake index 619a4c3d08dc9..a29cd85a94f7e 100644 --- a/cmake/onnxruntime_mlas.cmake +++ b/cmake/onnxruntime_mlas.cmake @@ -4,6 +4,7 @@ set(mlas_common_srcs ${ONNXRUNTIME_ROOT}/core/mlas/lib/platform.cpp ${ONNXRUNTIME_ROOT}/core/mlas/lib/threading.cpp + ${ONNXRUNTIME_ROOT}/core/mlas/lib/qgemm.cpp ${ONNXRUNTIME_ROOT}/core/mlas/lib/sgemm.cpp ${ONNXRUNTIME_ROOT}/core/mlas/lib/convolve.cpp ${ONNXRUNTIME_ROOT}/core/mlas/lib/pooling.cpp @@ -16,12 +17,10 @@ set(mlas_common_srcs ) if(MSVC) - if(CMAKE_GENERATOR_PLATFORM STREQUAL "ARM64") - - set(asm_filename ${ONNXRUNTIME_ROOT}/core/mlas/lib/arm64/sgemma.asm) - set(pre_filename ${CMAKE_CURRENT_BINARY_DIR}/sgemma.i) - set(obj_filename ${CMAKE_CURRENT_BINARY_DIR}/sgemma.obj) + set(asm_filename ${ONNXRUNTIME_ROOT}/core/mlas/lib/arm64/SgemmKernelNeon.asm) + set(pre_filename ${CMAKE_CURRENT_BINARY_DIR}/SgemmKernelNeon.i) + set(obj_filename ${CMAKE_CURRENT_BINARY_DIR}/SgemmKernelNeon.obj) if(CMAKE_BUILD_TYPE STREQUAL "Debug") set(ARMASM_FLAGS "-g") @@ -36,20 +35,18 
@@ if(MSVC) COMMAND armasm64.exe ${ARMASM_FLAGS} ${pre_filename} ${obj_filename} ) - set(mlas_platform_srcs ${obj_filename}) - elseif(CMAKE_GENERATOR_PLATFORM STREQUAL "ARM" OR CMAKE_GENERATOR MATCHES "ARM") - set(mlas_platform_srcs ${ONNXRUNTIME_ROOT}/core/mlas/lib/arm/sgemmc.cpp ) - elseif(CMAKE_GENERATOR_PLATFORM STREQUAL "x64" OR CMAKE_GENERATOR MATCHES "Win64") - enable_language(ASM_MASM) set(mlas_platform_srcs + ${ONNXRUNTIME_ROOT}/core/mlas/lib/amd64/QgemmU8U8KernelAvx2.asm + ${ONNXRUNTIME_ROOT}/core/mlas/lib/amd64/QgemmU8U8KernelAvx512BW.asm + ${ONNXRUNTIME_ROOT}/core/mlas/lib/amd64/QgemmU8U8KernelAvx512Vnni.asm ${ONNXRUNTIME_ROOT}/core/mlas/lib/amd64/SgemmKernelSse2.asm ${ONNXRUNTIME_ROOT}/core/mlas/lib/amd64/SgemmKernelAvx.asm ${ONNXRUNTIME_ROOT}/core/mlas/lib/amd64/SgemmKernelFma3.asm @@ -67,9 +64,7 @@ if(MSVC) ${ONNXRUNTIME_ROOT}/core/mlas/lib/amd64/TanhKernelFma3.asm ${ONNXRUNTIME_ROOT}/core/mlas/lib/amd64/ErfKernelFma3.asm ) - else() - enable_language(ASM_MASM) set(CMAKE_ASM_MASM_FLAGS "${CMAKE_ASM_MASM_FLAGS} /safeseh") @@ -77,14 +72,13 @@ if(MSVC) set(mlas_platform_srcs ${ONNXRUNTIME_ROOT}/core/mlas/lib/i386/sgemma.asm ) - endif() else() if (CMAKE_SYSTEM_NAME STREQUAL "Android") if (CMAKE_ANDROID_ARCH_ABI STREQUAL "armeabi-v7a") set(ARM TRUE) elseif (CMAKE_ANDROID_ARCH_ABI STREQUAL "arm64-v8a") - set(ARM TRUE) # Android NDK fails to compile sgemma.s + set(ARM64 TRUE) elseif (CMAKE_ANDROID_ARCH_ABI STREQUAL "x86_64") set(X86_64 TRUE) elseif (CMAKE_ANDROID_ARCH_ABI STREQUAL "x86") @@ -95,8 +89,7 @@ else() COMMAND ${CMAKE_C_COMPILER} -dumpmachine OUTPUT_VARIABLE dumpmachine_output ERROR_QUIET - ) - + ) if(dumpmachine_output MATCHES "^arm.*") set(ARM TRUE) elseif(dumpmachine_output MATCHES "^aarch64.*") @@ -108,39 +101,39 @@ else() endif() endif() - if (ARM) + if(ARM) set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -mfpu=neon") set(mlas_platform_srcs ${ONNXRUNTIME_ROOT}/core/mlas/lib/arm/sgemmc.cpp - ) - elseif (ARM64) + ) + elseif(ARM64) enable_language(ASM) set(mlas_platform_srcs - ${ONNXRUNTIME_ROOT}/core/mlas/lib/aarch64/sgemma.s - ) - elseif (X86) + ${ONNXRUNTIME_ROOT}/core/mlas/lib/aarch64/SgemmKernelNeon.S + ) + elseif(X86) enable_language(ASM) set(mlas_platform_srcs_sse2 ${ONNXRUNTIME_ROOT}/core/mlas/lib/x86/SgemmKernelSse2.S - ) + ) set_source_files_properties(${mlas_platform_srcs_sse2} PROPERTIES COMPILE_FLAGS "-msse2") set(mlas_platform_srcs_avx ${ONNXRUNTIME_ROOT}/core/mlas/lib/x86/SgemmKernelAvx.S - ) + ) set_source_files_properties(${mlas_platform_srcs_avx} PROPERTIES COMPILE_FLAGS "-mavx") set(mlas_platform_srcs ${mlas_platform_srcs_sse2} ${mlas_platform_srcs_avx} - ) - elseif (X86_64) + ) + elseif(X86_64) enable_language(ASM) - # The LLVM assmebler does not support the .arch directive to enable instruction + # The LLVM assembler does not support the .arch directive to enable instruction # set extensions and also doesn't support AVX-512F instructions without # turning on support via command-line option. Group the sources by the # instruction set extension and explicitly set the compiler flag as appropriate. 
@@ -164,6 +157,7 @@ else() set_source_files_properties(${mlas_platform_srcs_avx} PROPERTIES COMPILE_FLAGS "-mavx") set(mlas_platform_srcs_avx2 + ${ONNXRUNTIME_ROOT}/core/mlas/lib/x86_64/QgemmU8U8KernelAvx2.S ${ONNXRUNTIME_ROOT}/core/mlas/lib/x86_64/SgemmKernelFma3.S ${ONNXRUNTIME_ROOT}/core/mlas/lib/x86_64/SconvKernelFma3.S ${ONNXRUNTIME_ROOT}/core/mlas/lib/x86_64/LogisticKernelFma3.S @@ -179,17 +173,22 @@ else() ) set_source_files_properties(${mlas_platform_srcs_avx512f} PROPERTIES COMPILE_FLAGS "-mavx512f") + set(mlas_platform_srcs_avx512bw + ${ONNXRUNTIME_ROOT}/core/mlas/lib/x86_64/QgemmU8U8KernelAvx512BW.S + ${ONNXRUNTIME_ROOT}/core/mlas/lib/x86_64/QgemmU8U8KernelAvx512Vnni.S + ) + set_source_files_properties(${mlas_platform_srcs_avx512bw} PROPERTIES COMPILE_FLAGS "-mavx512bw") + set(mlas_platform_srcs ${mlas_platform_srcs_sse2} ${mlas_platform_srcs_avx} ${mlas_platform_srcs_avx2} ${mlas_platform_srcs_avx512f} + ${mlas_platform_srcs_avx512bw} ) - endif() - endif() add_library(onnxruntime_mlas STATIC ${mlas_common_srcs} ${mlas_platform_srcs}) -target_include_directories(onnxruntime_mlas PRIVATE ${ONNXRUNTIME_ROOT}/core/mlas/inc ${ONNXRUNTIME_ROOT}/core/mlas/lib) +target_include_directories(onnxruntime_mlas PRIVATE ${ONNXRUNTIME_ROOT}/core/mlas/inc ${ONNXRUNTIME_ROOT}/core/mlas/lib ${eigen_INCLUDE_DIRS}) set_target_properties(onnxruntime_mlas PROPERTIES FOLDER "ONNXRuntime") diff --git a/cmake/onnxruntime_nuphar_extern.cmake b/cmake/onnxruntime_nuphar_extern.cmake new file mode 100644 index 0000000000000..9a34e82f204d6 --- /dev/null +++ b/cmake/onnxruntime_nuphar_extern.cmake @@ -0,0 +1,39 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +# this is for building extern functions in nuphar execution provider, using AVX2 +# the separation from onnxruntime_providers.cmake is to avoid unnecessary AVX2 codegen in providers +# functions built here would be dynamically switched based on if AVX2 is available from CPUID + +add_definitions(-DNUPHAR_USE_AVX2) + +set(extern_avx2_srcs + ${ONNXRUNTIME_ROOT}/core/providers/nuphar/extern/igemv_avx2.cc + ${ONNXRUNTIME_ROOT}/core/providers/nuphar/extern/igemv_avx2.h +) + +if (MSVC) + set_source_files_properties(${extern_avx2_srcs} PROPERTIES COMPILE_FLAGS "/arch:AVX2") +else() + set_source_files_properties(${extern_avx2_srcs} PROPERTIES COMPILE_FLAGS "-march=broadwell") +endif() + +set(nuphar_extern_srcs + ${extern_avx2_srcs} +) + +add_library(onnxruntime_nuphar_extern ${nuphar_extern_srcs}) + +if (onnxruntime_USE_MKLML) + add_definitions(-DNUPHAR_USE_MKL) + target_include_directories(onnxruntime_nuphar_extern PRIVATE ${ONNXRUNTIME_ROOT}/core/providers/nuphar/extern ${MKLML_INCLUDE_DIR}) + add_dependencies(onnxruntime_nuphar_extern project_mklml) +else() + target_include_directories(onnxruntime_nuphar_extern PRIVATE ${ONNXRUNTIME_ROOT}/core/providers/nuphar/extern) +endif() + +set_target_properties(onnxruntime_nuphar_extern PROPERTIES FOLDER "ONNXRuntime") + +list(APPEND onnxruntime_EXTERNAL_LIBRARIES onnxruntime_nuphar_extern) +list(APPEND onnxruntime_EXTERNAL_DEPENDENCIES onnxruntime_nuphar_extern) +link_directories(${CMAKE_CURRENT_BINARY_DIR}/${CMAKE_BUILD_TYPE}) diff --git a/cmake/onnxruntime_providers.cmake b/cmake/onnxruntime_providers.cmake index 0447e4814d37d..4c0abeb970aaa 100644 --- a/cmake/onnxruntime_providers.cmake +++ b/cmake/onnxruntime_providers.cmake @@ -25,6 +25,16 @@ file(GLOB_RECURSE onnxruntime_cuda_contrib_ops_cu_srcs CONFIGURE_DEPENDS "${ONNXRUNTIME_ROOT}/contrib_ops/cuda/*.cuh" ) 
+file(GLOB onnxruntime_cpu_automl_cc_srcs CONFIGURE_DEPENDS + "${ONNXRUNTIME_ROOT}/automl_ops/cpu_automl_kernels.h" + "${ONNXRUNTIME_ROOT}/automl_ops/cpu_automl_kernels.cc" + "${ONNXRUNTIME_ROOT}/automl_ops/automl_types.h" + "${ONNXRUNTIME_ROOT}/automl_ops/automl_types.cc" + "${ONNXRUNTIME_ROOT}/automl_ops/automl_featurizers.h" + "${ONNXRUNTIME_ROOT}/automl_ops/cpu/*.h" + "${ONNXRUNTIME_ROOT}/automl_ops/cpu/*.cc" +) + file(GLOB onnxruntime_providers_common_srcs CONFIGURE_DEPENDS "${ONNXRUNTIME_ROOT}/core/providers/*.h" "${ONNXRUNTIME_ROOT}/core/providers/*.cc" @@ -38,6 +48,10 @@ if(onnxruntime_USE_NGRAPH) set(PROVIDERS_NGRAPH onnxruntime_providers_ngraph) list(APPEND ONNXRUNTIME_PROVIDER_NAMES ngraph) endif() +if(onnxruntime_USE_NUPHAR) + set(PROVIDERS_NUPHAR onnxruntime_providers_nuphar) + list(APPEND ONNXRUNTIME_PROVIDER_NAMES nuphar) +endif() if(onnxruntime_USE_CUDA) set(PROVIDERS_CUDA onnxruntime_providers_cuda) list(APPEND ONNXRUNTIME_PROVIDER_NAMES cuda) @@ -55,17 +69,30 @@ if(onnxruntime_USE_NNAPI) list(APPEND ONNXRUNTIME_PROVIDER_NAMES nnapi) endif() source_group(TREE ${ONNXRUNTIME_ROOT}/core FILES ${onnxruntime_providers_common_srcs} ${onnxruntime_providers_srcs}) -# add using ONNXRUNTIME_ROOT so they show up under the 'contrib_ops' folder in Visual Studio -source_group(TREE ${ONNXRUNTIME_ROOT} FILES ${onnxruntime_cpu_contrib_ops_srcs}) + +set(onnxruntime_providers_src ${onnxruntime_providers_common_srcs} ${onnxruntime_providers_srcs}) # disable contrib ops conditionally -if(onnxruntime_DISABLE_CONTRIB_OPS) - add_library(onnxruntime_providers ${onnxruntime_providers_common_srcs} ${onnxruntime_providers_srcs}) -else() - add_library(onnxruntime_providers ${onnxruntime_providers_common_srcs} ${onnxruntime_providers_srcs} ${onnxruntime_cpu_contrib_ops_srcs}) +if(NOT onnxruntime_DISABLE_CONTRIB_OPS) + # add using ONNXRUNTIME_ROOT so they show up under the 'contrib_ops' folder in Visual Studio + source_group(TREE ${ONNXRUNTIME_ROOT} FILES ${onnxruntime_cpu_contrib_ops_srcs}) + list(APPEND onnxruntime_providers_src ${onnxruntime_cpu_contrib_ops_srcs}) +endif() + +if (onnxruntime_USE_AUTOML) + source_group(TREE ${ONNXRUNTIME_ROOT}/ FILES ${onnxruntime_cpu_automl_cc_srcs}) + list(APPEND onnxruntime_providers_src ${onnxruntime_cpu_automl_cc_srcs}) endif() +add_library(onnxruntime_providers ${onnxruntime_providers_src}) onnxruntime_add_include_to_target(onnxruntime_providers onnxruntime_common onnxruntime_framework gsl onnx onnx_proto protobuf::libprotobuf) + +if (onnxruntime_USE_AUTOML) + add_dependencies(onnxruntime_providers automl_featurizers) + onnxruntime_add_include_to_target(onnxruntime_providers automl_featurizers) + target_link_libraries(onnxruntime_providers automl_featurizers) +endif() + if(HAS_DEPRECATED_COPY) #temporarily ignore this warning #see: https://en.wikipedia.org/wiki/Rule_of_three_(C%2B%2B_programming) @@ -78,17 +105,6 @@ if(HAS_DEPRECATED_COPY) set_source_files_properties("${ONNXRUNTIME_ROOT}/core/providers/cpu/tensor/where_op.cc" PROPERTIES COMPILE_FLAGS -Wno-deprecated-copy) endif() -if(CMAKE_SYSTEM_PROCESSOR STREQUAL "x86_64" OR CMAKE_SYSTEM_PROCESSOR STREQUAL "AMD64" AND NOT MSVC) - # For x86 platforms it is important to pass this flag to compiler. Without this gemmlowp will use slow reference code. - # These optimizations are not enabled on MSVC so excluding it. 
- message("enabling optimizations for gemmlowp") - set_source_files_properties("${ONNXRUNTIME_ROOT}/core/providers/cpu/math/matmul_integer.cc" PROPERTIES COMPILE_FLAGS "-msse4.1") - set_source_files_properties("${ONNXRUNTIME_ROOT}/core/providers/cpu/math/quantize_linear_matmul.cc" PROPERTIES COMPILE_FLAGS "-msse4.1") - set_source_files_properties("${ONNXRUNTIME_ROOT}/core/providers/cpu/nn/qlinearconv.cc" PROPERTIES COMPILE_FLAGS "-msse4.1") - set_source_files_properties("${ONNXRUNTIME_ROOT}/core/providers/cpu/nn/conv_integer.cc" PROPERTIES COMPILE_FLAGS "-msse4.1") -endif() - -set(gemmlowp_src ${PROJECT_SOURCE_DIR}/external/gemmlowp) set(re2_src ${ONNXRUNTIME_ROOT}/../cmake/external/re2) target_include_directories(onnxruntime_providers PRIVATE ${ONNXRUNTIME_ROOT} ${eigen_INCLUDE_DIRS} ${gemmlowp_src} ${re2_src}) add_dependencies(onnxruntime_providers gsl onnx ${onnxruntime_EXTERNAL_DEPENDENCIES}) @@ -306,6 +322,43 @@ endif() file(COPY ${onnxruntime_providers_openvino_py_srcs} DESTINATION ${onnxruntime_BINARY_DIR}) endif() +if (onnxruntime_USE_NUPHAR) + add_definitions(-DUSE_NUPHAR=1) + + if (NOT onnxruntime_USE_TVM) + message(FATAL_ERROR "onnxruntime_USE_TVM required for onnxruntime_USE_NUPHAR") + endif() + + if (NOT onnxruntime_USE_LLVM) + message(FATAL_ERROR "onnxruntime_USE_LLVM required for onnxruntime_USE_NUPHAR") + endif() + + include(onnxruntime_nuphar_extern.cmake) + + file(GLOB_RECURSE onnxruntime_providers_nuphar_cc_srcs + "${ONNXRUNTIME_ROOT}/core/providers/nuphar/*.h" + "${ONNXRUNTIME_ROOT}/core/providers/nuphar/*.cc" + ) + + # following files required different build flag for AVX2 in separate onnxruntime_nuphar_extern.cmake file + list (REMOVE_ITEM onnxruntime_providers_nuphar_cc_srcs "${ONNXRUNTIME_ROOT}/core/providers/nuphar/extern/igemv_avx2.cc") + list (REMOVE_ITEM onnxruntime_providers_nuphar_cc_srcs "${ONNXRUNTIME_ROOT}/core/providers/nuphar/extern/igemv_avx2.h") + + if (onnxruntime_USE_MKLML) + add_definitions(-DNUPHAR_USE_MKL) + endif() + + source_group(TREE ${ONNXRUNTIME_ROOT}/core FILES ${onnxruntime_providers_nuphar_cc_srcs}) + add_library(onnxruntime_providers_nuphar ${onnxruntime_providers_nuphar_cc_srcs}) + onnxruntime_add_include_to_target(onnxruntime_providers_nuphar onnxruntime_common onnxruntime_framework gsl onnx onnx_proto protobuf::libprotobuf) + set_target_properties(onnxruntime_providers_nuphar PROPERTIES FOLDER "ONNXRuntime") + target_include_directories(onnxruntime_providers_nuphar PRIVATE ${ONNXRUNTIME_ROOT} ${TVM_INCLUDES} ${eigen_INCLUDE_DIRS}) + set_target_properties(onnxruntime_providers_nuphar PROPERTIES LINKER_LANGUAGE CXX) + target_compile_options(onnxruntime_providers_nuphar PRIVATE ${DISABLED_WARNINGS_FOR_TVM}) + add_dependencies(onnxruntime_providers_nuphar ${onnxruntime_EXTERNAL_DEPENDENCIES}) + install(DIRECTORY ${PROJECT_SOURCE_DIR}/../include/onnxruntime/core/providers/nuphar DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/onnxruntime/core/providers) +endif() + if (onnxruntime_USE_NNAPI) add_definitions(-DUSE_NNAPI=1) option(DNN_READ_ONNX "" ON) diff --git a/cmake/onnxruntime_python.cmake b/cmake/onnxruntime_python.cmake index c9fcb91ff359d..a01317df6eff4 100644 --- a/cmake/onnxruntime_python.cmake +++ b/cmake/onnxruntime_python.cmake @@ -73,6 +73,7 @@ set(onnxruntime_pybind11_state_libs ${PROVIDERS_TENSORRT} ${PROVIDERS_NGRAPH} ${PROVIDERS_OPENVINO} + ${PROVIDERS_NUPHAR} ${PROVIDERS_NNAPI} onnxruntime_optimizer onnxruntime_providers @@ -234,3 +235,15 @@ if (onnxruntime_USE_MKLML) $/onnxruntime/capi/ ) endif() + +if (onnxruntime_USE_NUPHAR) + 
file(GLOB onnxruntime_python_nuphar_test_srcs CONFIGURE_DEPENDS + "${ONNXRUNTIME_ROOT}/core/providers/nuphar/scripts/*.*" + ) + add_custom_command( + TARGET onnxruntime_pybind11_state POST_BUILD + COMMAND ${CMAKE_COMMAND} -E copy + ${onnxruntime_python_nuphar_test_srcs} + $ + ) +endif() \ No newline at end of file diff --git a/cmake/onnxruntime_unittests.cmake b/cmake/onnxruntime_unittests.cmake index 3223e263a21e1..c2cf520cd32ed 100644 --- a/cmake/onnxruntime_unittests.cmake +++ b/cmake/onnxruntime_unittests.cmake @@ -126,6 +126,12 @@ if(NOT onnxruntime_DISABLE_CONTRIB_OPS) "${TEST_SRC_DIR}/contrib_ops/*.cc") endif() +if(onnxruntime_USE_AUTOML) + list(APPEND onnxruntime_test_providers_src_patterns + "${TEST_SRC_DIR}/automl_ops/*.h" + "${TEST_SRC_DIR}/automl_ops/*.cc") +endif() + file(GLOB onnxruntime_test_providers_src CONFIGURE_DEPENDS ${onnxruntime_test_providers_src_patterns}) file(GLOB_RECURSE onnxruntime_test_providers_cpu_src CONFIGURE_DEPENDS @@ -209,6 +215,10 @@ if(onnxruntime_USE_NNAPI) list(APPEND onnxruntime_test_providers_dependencies onnxruntime_providers_nnapi) endif() +if(onnxruntime_USE_AUTOML) + list(APPEND onnxruntime_test_providers_dependencies automl_featurizers) +endif() + file(GLOB_RECURSE onnxruntime_test_tvm_src CONFIGURE_DEPENDS "${ONNXRUNTIME_ROOT}/test/tvm/*.h" "${ONNXRUNTIME_ROOT}/test/tvm/*.cc" @@ -219,6 +229,13 @@ file(GLOB_RECURSE onnxruntime_test_openvino_src "${ONNXRUNTIME_ROOT}/test/openvino/*.cc" ) +if(onnxruntime_USE_NUPHAR) + list(APPEND onnxruntime_test_framework_src_patterns ${TEST_SRC_DIR}/framework/nuphar/*) + list(APPEND onnxruntime_test_framework_libs onnxruntime_providers_nuphar) + list(APPEND onnxruntime_test_providers_dependencies onnxruntime_providers_nuphar) + list(APPEND onnxruntime_test_providers_libs onnxruntime_providers_nuphar) +endif() + if (onnxruntime_ENABLE_MICROSOFT_INTERNAL) include(onnxruntime_unittests_internal.cmake) endif() @@ -231,6 +248,7 @@ set(ONNXRUNTIME_TEST_LIBS ${PROVIDERS_TENSORRT} ${PROVIDERS_NGRAPH} ${PROVIDERS_OPENVINO} + ${PROVIDERS_NUPHAR} ${PROVIDERS_NNAPI} onnxruntime_optimizer onnxruntime_providers @@ -471,7 +489,12 @@ set(onnx_test_runner_common_srcs ${onnx_test_runner_src_dir}/TestCase.h ${onnx_test_runner_src_dir}/onnxruntime_event.h ${onnx_test_runner_src_dir}/sync_api.h - ${onnx_test_runner_src_dir}/sync_api.cc) + ${onnx_test_runner_src_dir}/sync_api.cc + ${onnx_test_runner_src_dir}/callback.h + ${onnx_test_runner_src_dir}/callback.cc + ${onnx_test_runner_src_dir}/mem_buffer.h + ${onnx_test_runner_src_dir}/tensorprotoutils.h + ${onnx_test_runner_src_dir}/tensorprotoutils.cc) if(WIN32) set(wide_get_opt_src_dir ${TEST_SRC_DIR}/win_getopt/wide) @@ -505,13 +528,19 @@ onnxruntime_add_include_to_target(onnx_test_runner gsl) target_include_directories(onnx_test_runner PRIVATE ${ONNXRUNTIME_ROOT}) set_target_properties(onnx_test_runner PROPERTIES FOLDER "ONNXRuntimeTest") +if (onnxruntime_USE_TVM) + if (WIN32) + target_link_options(onnx_test_runner PRIVATE "/STACK:4000000") + endif() +endif() + install(TARGETS onnx_test_runner ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR}) if(onnxruntime_BUILD_BENCHMARKS) - add_executable(onnxruntime_benchmark ${TEST_SRC_DIR}/onnx/microbenchmark/main.cc ${TEST_SRC_DIR}/onnx/microbenchmark/modeltest.cc ${TEST_SRC_DIR}/onnx/microbenchmark/model_init.cc) + add_executable(onnxruntime_benchmark ${TEST_SRC_DIR}/onnx/microbenchmark/main.cc ${TEST_SRC_DIR}/onnx/microbenchmark/modeltest.cc) 
target_include_directories(onnxruntime_benchmark PRIVATE ${ONNXRUNTIME_ROOT} ${onnxruntime_graph_header} benchmark) onnxruntime_add_include_to_target(onnxruntime_benchmark gsl) if(WIN32) @@ -585,6 +614,12 @@ if (onnxruntime_ENABLE_LANGUAGE_INTEROP_OPS AND NOT onnxruntime_BUILD_SHARED_LIB target_link_libraries(onnxruntime_perf_test PRIVATE onnxruntime_language_interop onnxruntime_pyop) endif() +if (onnxruntime_USE_TVM) + if (WIN32) + target_link_options(onnxruntime_perf_test PRIVATE "/STACK:4000000") + endif() +endif() + # shared lib if (onnxruntime_BUILD_SHARED_LIB) add_library(onnxruntime_mocked_allocator ${ONNXRUNTIME_ROOT}/test/util/test_allocator.cc) @@ -606,7 +641,6 @@ if (onnxruntime_BUILD_SHARED_LIB) endif() if (NOT(${CMAKE_SYSTEM_NAME} MATCHES "Darwin")) #for some reason, these tests are failing. Need investigation. - list(APPEND onnxruntime_shared_lib_test_SRC ${ONNXRUNTIME_SHARED_LIB_TEST_SRC_DIR}/test_tensor_loader.cc) if (onnxruntime_USE_FULL_PROTOBUF) list(APPEND onnxruntime_shared_lib_test_SRC ${ONNXRUNTIME_SHARED_LIB_TEST_SRC_DIR}/test_model_loading.cc) endif() diff --git a/cmake/onnxruntime_util.cmake b/cmake/onnxruntime_util.cmake index a8b611d7c99c0..feea9f90ee80f 100644 --- a/cmake/onnxruntime_util.cmake +++ b/cmake/onnxruntime_util.cmake @@ -8,8 +8,17 @@ file(GLOB_RECURSE onnxruntime_util_srcs CONFIGURE_DEPENDS source_group(TREE ${ONNXRUNTIME_ROOT}/core FILES ${onnxruntime_util_srcs}) +if(CMAKE_SYSTEM_PROCESSOR STREQUAL "x86_64" OR CMAKE_SYSTEM_PROCESSOR STREQUAL "AMD64" AND NOT MSVC) + # For x86 platforms it is important to pass this flag to compiler. Without this gemmlowp will use slow reference code. + # These optimizations are not enabled on MSVC so excluding it. + message("enabling optimizations for gemmlowp") + set_source_files_properties("${ONNXRUNTIME_ROOT}/core/util/gemmlowp_common.cc" PROPERTIES COMPILE_FLAGS "-msse4.1") +endif() + +set(gemmlowp_src ${PROJECT_SOURCE_DIR}/external/gemmlowp) + add_library(onnxruntime_util ${onnxruntime_util_srcs}) -target_include_directories(onnxruntime_util PRIVATE ${ONNXRUNTIME_ROOT} ${MKLML_INCLUDE_DIR} PUBLIC ${eigen_INCLUDE_DIRS}) +target_include_directories(onnxruntime_util PRIVATE ${ONNXRUNTIME_ROOT} ${MKLML_INCLUDE_DIR} ${gemmlowp_src} PUBLIC ${eigen_INCLUDE_DIRS}) onnxruntime_add_include_to_target(onnxruntime_util onnxruntime_common onnxruntime_framework gsl onnx onnx_proto protobuf::libprotobuf) if(UNIX) target_compile_options(onnxruntime_util PUBLIC "-Wno-error=comment") diff --git a/cmake/patches/ngraph/ngraph_fix_install_error.patch b/cmake/patches/ngraph/ngraph_fix_install_error.patch deleted file mode 100644 index ddabbb7d86527..0000000000000 --- a/cmake/patches/ngraph/ngraph_fix_install_error.patch +++ /dev/null @@ -1,127 +0,0 @@ -From 280fbc003ea2794adb24d6a81d42db838a793dd9 Mon Sep 17 00:00:00 2001 -From: Sang Ik Lee -Date: Mon, 15 Apr 2019 16:11:27 -0700 -Subject: [PATCH] CMAKE_CFG_INTDIR does not work at install time. Use - CMAKE_INSTALL_CONFIG_NAME on Windows. 
- ---- - CMakeLists.txt | 7 ++++++- - cmake/external_mkldnn.cmake | 22 +++++++++++----------- - cmake/external_tbb.cmake | 4 ++-- - cmake/external_tbb_prebuilt.cmake | 6 +++--- - 4 files changed, 22 insertions(+), 17 deletions(-) - -diff --git a/CMakeLists.txt b/CMakeLists.txt -index 2a21ed3a3..a695e217f 100755 ---- a/CMakeLists.txt -+++ b/CMakeLists.txt -@@ -390,12 +390,17 @@ endif() - - set(NGRAPH_BUILD_DIR ${CMAKE_CURRENT_BINARY_DIR}/src/ngraph) - set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${NGRAPH_BUILD_DIR}) --set(NGRAPH_LIBRARY_OUTPUT_DIRECTORY ${NGRAPH_BUILD_DIR}/${CMAKE_CFG_INTDIR}) - if(WIN32) -+ set(NGRAPH_LIBRARY_OUTPUT_DIRECTORY ${NGRAPH_BUILD_DIR}/${CMAKE_CFG_INTDIR}) -+ set(NGRAPH_LIBRARY_INSTALLSRC_DIRECTORY ${NGRAPH_BUILD_DIR}/\${CMAKE_INSTALL_CONFIG_NAME}) - set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY ${NGRAPH_BUILD_DIR}) - set(NGRAPH_ARCHIVE_OUTPUT_DIRECTORY ${NGRAPH_BUILD_DIR}/${CMAKE_CFG_INTDIR}) -+ set(NGRAPH_ARCHIVE_INSTALLSRC_DIRECTORY ${NGRAPH_BUILD_DIR}/\${CMAKE_INSTALL_CONFIG_NAME}) - set(CMAKE_PDB_OUTPUT_DIRECTORY ${NGRAPH_BUILD_DIR}) - set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${NGRAPH_BUILD_DIR}) -+else() -+ set(NGRAPH_LIBRARY_OUTPUT_DIRECTORY ${NGRAPH_BUILD_DIR}) -+ set(NGRAPH_LIBRARY_INSTALLSRC_DIRECTORY ${NGRAPH_BUILD_DIR}) - endif() - - set(EXTERNAL_INSTALL_DIR ${CMAKE_BINARY_DIR}/external) -diff --git a/cmake/external_mkldnn.cmake b/cmake/external_mkldnn.cmake -index 25445bf0b..7874aca76 100644 ---- a/cmake/external_mkldnn.cmake -+++ b/cmake/external_mkldnn.cmake -@@ -312,12 +312,12 @@ endif() - if(WIN32) - install( - FILES -- ${NGRAPH_LIBRARY_OUTPUT_DIRECTORY}/${MKLML_LIB} -- ${NGRAPH_ARCHIVE_OUTPUT_DIRECTORY}/${MKLML_IMPLIB} -- ${NGRAPH_LIBRARY_OUTPUT_DIRECTORY}/${OMP_LIB} -- ${NGRAPH_ARCHIVE_OUTPUT_DIRECTORY}/${OMP_IMPLIB} -- ${NGRAPH_LIBRARY_OUTPUT_DIRECTORY}/${MKLDNN_LIB} -- ${NGRAPH_ARCHIVE_OUTPUT_DIRECTORY}/${MKLDNN_IMPLIB} -+ ${NGRAPH_LIBRARY_INSTALLSRC_DIRECTORY}/${MKLML_LIB} -+ ${NGRAPH_ARCHIVE_INSTALLSRC_DIRECTORY}/${MKLML_IMPLIB} -+ ${NGRAPH_LIBRARY_INSTALLSRC_DIRECTORY}/${OMP_LIB} -+ ${NGRAPH_ARCHIVE_INSTALLSRC_DIRECTORY}/${OMP_IMPLIB} -+ ${NGRAPH_LIBRARY_INSTALLSRC_DIRECTORY}/${MKLDNN_LIB} -+ ${NGRAPH_ARCHIVE_INSTALLSRC_DIRECTORY}/${MKLDNN_IMPLIB} - DESTINATION - ${NGRAPH_INSTALL_LIB} - OPTIONAL -@@ -325,9 +325,9 @@ if(WIN32) - else() - install( - FILES -- ${NGRAPH_LIBRARY_OUTPUT_DIRECTORY}/${MKLML_LIB} -- ${NGRAPH_LIBRARY_OUTPUT_DIRECTORY}/${OMP_LIB} -- ${NGRAPH_LIBRARY_OUTPUT_DIRECTORY}/${MKLDNN_LIB} -+ ${NGRAPH_LIBRARY_INSTALLSRC_DIRECTORY}/${MKLML_LIB} -+ ${NGRAPH_LIBRARY_INSTALLSRC_DIRECTORY}/${OMP_LIB} -+ ${NGRAPH_LIBRARY_INSTALLSRC_DIRECTORY}/${MKLDNN_LIB} - DESTINATION - ${NGRAPH_INSTALL_LIB} - OPTIONAL -@@ -335,8 +335,8 @@ else() - if(NGRAPH_LIB_VERSIONING_ENABLE) - install( - FILES -- ${NGRAPH_LIBRARY_OUTPUT_DIRECTORY}/${MKLDNN_SHORT_LIB} -- ${NGRAPH_LIBRARY_OUTPUT_DIRECTORY}/${MKLDNN_FULL_LIB} -+ ${NGRAPH_LIBRARY_INSTALLSRC_DIRECTORY}/${MKLDNN_SHORT_LIB} -+ ${NGRAPH_LIBRARY_INSTALLSRC_DIRECTORY}/${MKLDNN_FULL_LIB} - DESTINATION - ${NGRAPH_INSTALL_LIB} - OPTIONAL -diff --git a/cmake/external_tbb.cmake b/cmake/external_tbb.cmake -index 761c5b3bd..6960ea929 100644 ---- a/cmake/external_tbb.cmake -+++ b/cmake/external_tbb.cmake -@@ -63,10 +63,10 @@ if(NGRAPH_TBB_ENABLE) - ${TBB_BUILD_DIR}/${TBB_LIB}.${TBB_SOVER} - DESTINATION ${NGRAPH_LIBRARY_OUTPUT_DIRECTORY}) - endif() -- install(FILES ${NGRAPH_LIBRARY_OUTPUT_DIRECTORY}/${TBB_LIB} -+ install(FILES ${NGRAPH_LIBRARY_INSTALLSRC_DIRECTORY}/${TBB_LIB} - DESTINATION ${NGRAPH_INSTALL_LIB}) - if(LINUX) -- 
install(FILES ${NGRAPH_LIBRARY_OUTPUT_DIRECTORY}/${TBB_LIB}.${TBB_SOVER} -+ install(FILES ${NGRAPH_LIBRARY_INSTALLSRC_DIRECTORY}/${TBB_LIB}.${TBB_SOVER} - DESTINATION ${NGRAPH_INSTALL_LIB}) - endif() - add_library(libtbb INTERFACE) -diff --git a/cmake/external_tbb_prebuilt.cmake b/cmake/external_tbb_prebuilt.cmake -index 3e1d0688f..a1cf1922a 100644 ---- a/cmake/external_tbb_prebuilt.cmake -+++ b/cmake/external_tbb_prebuilt.cmake -@@ -69,8 +69,8 @@ if (WIN32) - DEPENDEES download - ) - -- install(FILES ${NGRAPH_ARCHIVE_OUTPUT_DIRECTORY}/${TBB_LIB_NAME}${CMAKE_STATIC_LIBRARY_SUFFIX} -- ${NGRAPH_LIBRARY_OUTPUT_DIRECTORY}/${TBB_LIB_NAME}${CMAKE_SHARED_LIBRARY_SUFFIX} -+ install(FILES ${NGRAPH_ARCHIVE_INSTALLSRC_DIRECTORY}/${TBB_LIB_NAME}${CMAKE_STATIC_LIBRARY_SUFFIX} -+ ${NGRAPH_LIBRARY_INSTALLSRC_DIRECTORY}/${TBB_LIB_NAME}${CMAKE_SHARED_LIBRARY_SUFFIX} - DESTINATION ${NGRAPH_INSTALL_LIB}) - elseif(APPLE) - set(TBB_LINK_LIBS -@@ -82,7 +82,7 @@ elseif(APPLE) - COMMENT "Move tbb libraries to ngraph build directory" - ) - -- install(FILES ${NGRAPH_LIBRARY_OUTPUT_DIRECTORY}/${CMAKE_SHARED_LIBRARY_PREFIX}${TBB_LIB_NAME}${CMAKE_SHARED_LIBRARY_SUFFIX} -+ install(FILES ${NGRAPH_LIBRARY_INSTALLSRC_DIRECTORY}/${CMAKE_SHARED_LIBRARY_PREFIX}${TBB_LIB_NAME}${CMAKE_SHARED_LIBRARY_SUFFIX} - DESTINATION ${NGRAPH_INSTALL_LIB}) - endif() - --- -2.13.0.windows.1 - diff --git a/cmake/patches/ngraph/ngraph_fix_library_path.patch b/cmake/patches/ngraph/ngraph_fix_library_path.patch deleted file mode 100644 index aaa63e96e7d78..0000000000000 --- a/cmake/patches/ngraph/ngraph_fix_library_path.patch +++ /dev/null @@ -1,33 +0,0 @@ -From fcd51f874f4a96fb4ca77d762ed39ea1bf3f2c0d Mon Sep 17 00:00:00 2001 -From: Junfeng Dong -Date: Wed, 17 Apr 2019 13:42:42 -0700 -Subject: [PATCH] Fix dll library load path on Windows. 
- ---- - src/ngraph/runtime/backend_manager.cpp | 3 ++- - 1 file changed, 2 insertions(+), 1 deletion(-) - -diff --git a/src/ngraph/runtime/backend_manager.cpp b/src/ngraph/runtime/backend_manager.cpp -index eaa8fc26a..4d35c63ec 100644 ---- a/src/ngraph/runtime/backend_manager.cpp -+++ b/src/ngraph/runtime/backend_manager.cpp -@@ -123,7 +123,7 @@ unique_ptr runtime::BackendManager::create_backend(const std:: - static string find_my_file() - { - #ifdef _WIN32 -- HMODULE hModule = GetModuleHandleW(NULL); -+ HMODULE hModule = GetModuleHandleW(L"ngraph.dll"); - WCHAR wpath[MAX_PATH]; - GetModuleFileNameW(hModule, wpath, MAX_PATH); - wstring ws(wpath); -@@ -157,6 +157,7 @@ DL_HANDLE runtime::BackendManager::open_shared_library(string type) - string my_directory = file_util::get_directory(find_my_file()); - string library_path = file_util::path_join(my_directory, library_name); - #ifdef _WIN32 -+ SetDllDirectory((LPCSTR)my_directory.c_str()); - handle = LoadLibrary(library_path.c_str()); - #else - handle = dlopen(library_path.c_str(), RTLD_NOW | RTLD_GLOBAL); --- -2.13.0.windows.1 - diff --git a/cmake/patches/ngraph/ngraph_fix_mkldnn_missing_symbol.patch b/cmake/patches/ngraph/ngraph_fix_mkldnn_missing_symbol.patch new file mode 100644 index 0000000000000..96504c910003a --- /dev/null +++ b/cmake/patches/ngraph/ngraph_fix_mkldnn_missing_symbol.patch @@ -0,0 +1,64 @@ + cmake/external_mkldnn.cmake | 1 + + cmake/mkldnn_fix_missing_symbol.patch | 99 +++++++++++++++++++++++++++++++++++ + 2 files changed, 100 insertions(+) + create mode 100644 cmake/mkldnn_fix_missing_symbol.patch + +diff --git a/cmake/external_mkldnn.cmake b/cmake/external_mkldnn.cmake +index 7874aca76..bbae6d1a4 100644 +--- a/cmake/external_mkldnn.cmake ++++ b/cmake/external_mkldnn.cmake +@@ -194,7 +194,8 @@ if (WIN32) + CONFIGURE_COMMAND + PATCH_COMMAND ${MKLDNN_PATCH_REVERT_COMMAND} + COMMAND git apply --ignore-space-change --ignore-whitespace ${CMAKE_SOURCE_DIR}/cmake/${MKLDNN_PATCH_FILE} + COMMAND git apply --ignore-space-change --ignore-whitespace ${CMAKE_SOURCE_DIR}/cmake/mkldnn_fix_memory.patch ++ COMMAND git apply --ignore-space-change --ignore-whitespace ${CMAKE_SOURCE_DIR}/cmake/mkldnn_fix_missing_symbol.patch + CMAKE_GENERATOR ${CMAKE_GENERATOR} + CMAKE_GENERATOR_PLATFORM ${CMAKE_GENERATOR_PLATFORM} + CMAKE_GENERATOR_TOOLSET ${CMAKE_GENERATOR_TOOLSET} +diff --git a/cmake/mkldnn_fix_missing_symbol.patch b/cmake/mkldnn_fix_missing_symbol.patch +new file mode 100644 +index 000000000..ea1a3bd61 +--- /dev/null ++++ b/cmake/mkldnn_fix_missing_symbol.patch +@@ -0,0 +1,40 @@ ++commit d485a54ac2b07b7349dabd833961415315a18fea ++Author: Denis Samoilov ++Date: Sun Apr 14 20:11:33 2019 -0700 ++ ++ cpu: gemv: fix unresolved symbol ++ ++ Fixes #456 ++ ++diff --git a/src/cpu/gemm/gemm_driver.cpp b/src/cpu/gemm/gemm_driver.cpp ++index 0773b212..df7bc44d 100644 ++--- a/src/cpu/gemm/gemm_driver.cpp +++++ b/src/cpu/gemm/gemm_driver.cpp ++@@ -1304,10 +1304,8 @@ static mkldnn_status_t gemm_threading_driver( ++ (float *) arg->co); ++ } ++ ++- if (data_traits::data_type == data_type::s8) { ++- if (gemm_s8u8s32_jump_to_gemv_s8u8s32(arg)) { ++- return mkldnn_success; ++- } +++ if (gemm_s8u8s32_jump_to_gemv_s8u8s32(arg)) { +++ return mkldnn_success; ++ } ++ ++ int nthr = (mkldnn_in_parallel()) ? 
1 : mkldnn_get_max_threads(); ++diff --git a/src/cpu/gemm/s8x8s32/jit_avx512_core_gemv_s8u8s32.cpp b/src/cpu/gemm/s8x8s32/jit_avx512_core_gemv_s8u8s32.cpp ++index 73d50e40..81646a43 100644 ++--- a/src/cpu/gemm/s8x8s32/jit_avx512_core_gemv_s8u8s32.cpp +++++ b/src/cpu/gemm/s8x8s32/jit_avx512_core_gemv_s8u8s32.cpp ++@@ -29,6 +29,10 @@ namespace cpu { ++ template ++ int gemm_s8u8s32_jump_to_gemv_s8u8s32(T *arg); ++ +++template <> +++int gemm_s8u8s32_jump_to_gemv_s8u8s32( +++ gemm_info_t *arg) { return 0; } +++ ++ template <> ++ int gemm_s8u8s32_jump_to_gemv_s8u8s32( ++ gemm_info_t *arg) { diff --git a/csharp/OnnxRuntime.CSharp.proj b/csharp/OnnxRuntime.CSharp.proj index 696e019834ebe..c3bf76f2d087a 100644 --- a/csharp/OnnxRuntime.CSharp.proj +++ b/csharp/OnnxRuntime.CSharp.proj @@ -19,19 +19,19 @@ CMake creates a target to this project @@ -54,7 +54,7 @@ CMake creates a target to this project /> - + diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/OnnxRuntime.snk b/csharp/OnnxRuntime.snk similarity index 100% rename from csharp/src/Microsoft.ML.OnnxRuntime/OnnxRuntime.snk rename to csharp/OnnxRuntime.snk diff --git a/csharp/sample/Microsoft.ML.OnnxRuntime.InferenceSample/Program.cs b/csharp/sample/Microsoft.ML.OnnxRuntime.InferenceSample/Program.cs index 0025bb8429f4e..6fc9b63e1a163 100644 --- a/csharp/sample/Microsoft.ML.OnnxRuntime.InferenceSample/Program.cs +++ b/csharp/sample/Microsoft.ML.OnnxRuntime.InferenceSample/Program.cs @@ -6,7 +6,7 @@ using System.Text; using System.IO; using Microsoft.ML.OnnxRuntime; -using System.Numerics.Tensors; +using Microsoft.ML.OnnxRuntime.Tensors; namespace CSharpUsage { @@ -26,7 +26,7 @@ static void UseApi() // Optional : Create session options and set the graph optimization level for the session SessionOptions options = new SessionOptions(); - options.SetSessionGraphOptimizationLevel(2); + options.GraphOptimizationLevel = GraphOptimizationLevel.ORT_ENABLE_EXTENDED; using (var session = new InferenceSession(modelPath, options)) { diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/DisposableNamedOnnxValue.cs b/csharp/src/Microsoft.ML.OnnxRuntime/DisposableNamedOnnxValue.cs index 096501019771e..bb159ea29e577 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/DisposableNamedOnnxValue.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/DisposableNamedOnnxValue.cs @@ -3,7 +3,7 @@ using System; using System.Collections.Generic; -using System.Numerics.Tensors; +using Microsoft.ML.OnnxRuntime.Tensors; using System.Runtime.InteropServices; @@ -120,9 +120,15 @@ internal static DisposableNamedOnnxValue CreateTensorFromOnnxValue(string name, case TensorElementType.UInt8: result = DisposableNamedOnnxValueFromNativeTensor(name, nativeOnnxValue); break; + case TensorElementType.Int8: + result = DisposableNamedOnnxValueFromNativeTensor(name, nativeOnnxValue); + break; case TensorElementType.String: result = DisposableNamedOnnxValueFromNativeTensor(name, nativeOnnxValue); break; + case TensorElementType.Bool: + result = DisposableNamedOnnxValueFromNativeTensor(name, nativeOnnxValue); + break; default: throw new NotSupportedException("Tensor of element type: " + elemType + " is not supported"); @@ -134,9 +140,8 @@ internal static DisposableNamedOnnxValue CreateTensorFromOnnxValue(string name, internal static DisposableNamedOnnxValue CreateFromOnnxValue(string name, IntPtr nativeOnnxValue) { IntPtr allocator = IntPtr.Zero; - NativeApiStatus.VerifySuccess(NativeMethods.OrtCreateDefaultAllocator(out allocator)); + 
NativeApiStatus.VerifySuccess(NativeMethods.OrtGetAllocatorWithDefaultOptions(out allocator)); var ret = CreateFromOnnxValue(name, nativeOnnxValue, allocator); - NativeMethods.OrtReleaseAllocator(allocator); return (DisposableNamedOnnxValue)ret; } diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/InferenceSession.cs b/csharp/src/Microsoft.ML.OnnxRuntime/InferenceSession.cs index 5f89bad8bbe9b..79643029561da 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/InferenceSession.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/InferenceSession.cs @@ -19,6 +19,8 @@ public class InferenceSession : IDisposable { protected IntPtr _nativeHandle; protected Dictionary _inputMetadata, _outputMetadata; + private SessionOptions _builtInSessionOptions = null; + private RunOptions _builtInRunOptions = null; #region Public API @@ -28,10 +30,12 @@ public class InferenceSession : IDisposable /// /// public InferenceSession(string modelPath) - : this(modelPath, SessionOptions.Default) { + _builtInSessionOptions = new SessionOptions(); // need to be disposed + Init(modelPath, _builtInSessionOptions); } + /// /// Constructs an InferenceSession from a model file, with some additional session options /// @@ -39,52 +43,13 @@ public InferenceSession(string modelPath) /// public InferenceSession(string modelPath, SessionOptions options) { - var envHandle = OnnxRuntime.Handle; - - _nativeHandle = IntPtr.Zero; - try - { - if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows)) - NativeApiStatus.VerifySuccess(NativeMethods.OrtCreateSession(envHandle, System.Text.Encoding.Unicode.GetBytes(modelPath), options._nativePtr, out _nativeHandle)); - else - NativeApiStatus.VerifySuccess(NativeMethods.OrtCreateSession(envHandle, System.Text.Encoding.UTF8.GetBytes(modelPath), options._nativePtr, out _nativeHandle)); - - // Initialize input/output metadata - _inputMetadata = new Dictionary(); - _outputMetadata = new Dictionary(); - - // get input count - UIntPtr inputCount = UIntPtr.Zero; - NativeApiStatus.VerifySuccess(NativeMethods.OrtSessionGetInputCount(_nativeHandle, out inputCount)); - - // get all the output names - for (ulong i = 0; i < (ulong)inputCount; i++) - { - var iname = GetInputName(i); - _inputMetadata[iname] = GetInputMetadata(i); - } - // get output count - UIntPtr outputCount = UIntPtr.Zero; - NativeApiStatus.VerifySuccess(NativeMethods.OrtSessionGetOutputCount(_nativeHandle, out outputCount)); - - // get all the output names - for (ulong i = 0; i < (ulong)outputCount; i++) - { - _outputMetadata[GetOutputName(i)] = GetOutputMetadata(i); - } - - } - catch (OnnxRuntimeException e) - { - if (_nativeHandle != IntPtr.Zero) - { - NativeMethods.OrtReleaseSession(_nativeHandle); - _nativeHandle = IntPtr.Zero; - } - throw e; - } + Init(modelPath, options); } + + /// + /// Meta data regarding the input nodes, keyed by input names + /// public IReadOnlyDictionary InputMetadata { get @@ -93,6 +58,9 @@ public IReadOnlyDictionary InputMetadata } } + /// + /// Metadata regarding the output nodes, keyed by output names + /// public IReadOnlyDictionary OutputMetadata { get @@ -101,11 +69,12 @@ public IReadOnlyDictionary OutputMetadata } } + /// /// Runs the loaded model for the given inputs, and fetches all the outputs. /// /// - /// Output Tensors in a Collection of NamedOnnxValue + /// Output Tensors in a Collection of NamedOnnxValue. User must dispose the output. 
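A minimal usage sketch of the reworked managed Run API above, assuming a placeholder model path and input name and the existing NamedOnnxValue.CreateFromTensor helper; it shows the parameterless InferenceSession constructor (which now owns a built-in SessionOptions), the now-public Run overload taking RunOptions, and the caller disposing the returned collection as the new doc comment requires.

using System;
using System.Linq;
using Microsoft.ML.OnnxRuntime;
using Microsoft.ML.OnnxRuntime.Tensors;

static class RunSketch
{
    static void Main()
    {
        // "model.onnx" and the input name "data" are placeholders for illustration only.
        using (var session = new InferenceSession("model.onnx"))
        using (var runOptions = new RunOptions { LogTag = "sample-run" })
        {
            var input = new DenseTensor<float>(new[] { 1, 3 });
            var inputs = new[] { NamedOnnxValue.CreateFromTensor("data", input) };

            // The results own native memory, so the caller must dispose them.
            using (var results = session.Run(inputs, session.OutputMetadata.Keys.ToArray(), runOptions))
            {
                foreach (var r in results)
                    Console.WriteLine(r.Name);
            }
        }
    }
}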
public IDisposableReadOnlyCollection Run(IReadOnlyCollection inputs) { string[] outputNames = new string[_outputMetadata.Count]; @@ -118,21 +87,22 @@ public IDisposableReadOnlyCollection Run(IReadOnlyColl /// /// /// - /// Output Tensors in a Collection of NamedOnnxValue + /// Output Tensors in a Collection of NamedOnnxValue. User must dispose the output. public IDisposableReadOnlyCollection Run(IReadOnlyCollection inputs, IReadOnlyCollection outputNames) { - return Run(inputs, outputNames, RunOptions.Default); + IDisposableReadOnlyCollection result = null; + result = Run(inputs, outputNames, _builtInRunOptions); + return result; } /// - /// Runs the loaded model for the given inputs, and fetches the specified outputs in . + /// Runs the loaded model for the given inputs, and fetches the specified outputs in . /// /// /// /// - /// Output Tensors in a Collection of NamedOnnxValue - //TODO: kept internal until RunOptions is made public - internal IDisposableReadOnlyCollection Run(IReadOnlyCollection inputs, IReadOnlyCollection outputNames, RunOptions options) + /// Output Tensors in a Collection of NamedOnnxValue. User must dispose the output. + public IDisposableReadOnlyCollection Run(IReadOnlyCollection inputs, IReadOnlyCollection outputNames, RunOptions options) { var inputNames = new string[inputs.Count]; var inputTensors = new IntPtr[inputs.Count]; @@ -154,8 +124,7 @@ internal IDisposableReadOnlyCollection Run(IReadOnlyCo IntPtr status = NativeMethods.OrtRun( this._nativeHandle, - IntPtr.Zero, // TODO: use Run options when Run options creation API is available - // Passing null uses the default run options in the C-api + options.Handle, inputNames, inputTensors, (UIntPtr)(inputTensors.Length), @@ -192,7 +161,8 @@ internal IDisposableReadOnlyCollection Run(IReadOnlyCo // always unpin the input buffers, and delete the native Onnx value objects for (int i = 0; i < inputs.Count; i++) { - NativeMethods.OrtReleaseValue(inputTensors[i]); // this should not release the buffer, but should delete the native tensor object + NativeMethods.OrtReleaseValue(inputTensors[i]); // For elementary type Tensors, this should not release the buffer, but should delete the native tensor object. 
+ // For string tensors, this releases the native memory allocated for the tensor, including the buffer pinnedBufferHandles[i].Dispose(); } } @@ -211,6 +181,58 @@ internal ModelMetadata ModelMetadata #endregion #region private methods + + protected void Init(string modelPath, SessionOptions options) + { + var envHandle = OnnxRuntime.Handle; + + _nativeHandle = IntPtr.Zero; + try + { + if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows)) + NativeApiStatus.VerifySuccess(NativeMethods.OrtCreateSession(envHandle, System.Text.Encoding.Unicode.GetBytes(modelPath), options.Handle, out _nativeHandle)); + else + NativeApiStatus.VerifySuccess(NativeMethods.OrtCreateSession(envHandle, System.Text.Encoding.UTF8.GetBytes(modelPath), options.Handle, out _nativeHandle)); + + // Initialize input/output metadata + _inputMetadata = new Dictionary(); + _outputMetadata = new Dictionary(); + + // get input count + UIntPtr inputCount = UIntPtr.Zero; + NativeApiStatus.VerifySuccess(NativeMethods.OrtSessionGetInputCount(_nativeHandle, out inputCount)); + + // get all the output names + for (ulong i = 0; i < (ulong)inputCount; i++) + { + var iname = GetInputName(i); + _inputMetadata[iname] = GetInputMetadata(i); + } + // get output count + UIntPtr outputCount = UIntPtr.Zero; + NativeApiStatus.VerifySuccess(NativeMethods.OrtSessionGetOutputCount(_nativeHandle, out outputCount)); + + // get all the output names + for (ulong i = 0; i < (ulong)outputCount; i++) + { + _outputMetadata[GetOutputName(i)] = GetOutputMetadata(i); + } + + } + catch (OnnxRuntimeException e) + { + if (_nativeHandle != IntPtr.Zero) + { + NativeMethods.OrtReleaseSession(_nativeHandle); + _nativeHandle = IntPtr.Zero; + } + throw e; + } + + _builtInRunOptions = new RunOptions(); // create a default built-in run option, and avoid creating a new one every run() call + } + + private string GetOutputName(ulong index) { IntPtr nameHandle = IntPtr.Zero; @@ -358,6 +380,15 @@ protected virtual void Dispose(bool disposing) if (disposing) { // cleanup managed resources + if (_builtInSessionOptions != null) + { + _builtInSessionOptions.Dispose(); + } + + if (_builtInRunOptions != null) + { + _builtInRunOptions.Dispose(); + } } // cleanup unmanaged resources @@ -426,24 +457,5 @@ internal class ModelMetadata //TODO: placeholder for Model metadata. Currently C-API does not expose this. } - /// Sets various runtime options. - /// TODO: currently uses Default options only. kept internal until fully implemented - internal class RunOptions - { - protected static readonly Lazy _default = new Lazy(() => new RunOptions()); - - public static RunOptions Default - { - get - { - return _default.Value; - } - } - - private void RuntOptions() - { - - } - } } diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/Microsoft.ML.OnnxRuntime.csproj b/csharp/src/Microsoft.ML.OnnxRuntime/Microsoft.ML.OnnxRuntime.csproj index f7fbdcda281be..c966393f444bf 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/Microsoft.ML.OnnxRuntime.csproj +++ b/csharp/src/Microsoft.ML.OnnxRuntime/Microsoft.ML.OnnxRuntime.csproj @@ -3,10 +3,11 @@ netstandard1.1 AnyCPU;x86 + 7.2 true true false - OnnxRuntime.snk + ..\..\OnnxRuntime.snk ..\.. 
@@ -24,12 +25,11 @@ LICENSE.txt https://go.microsoft.com/fwlink/?linkid=2049168 - Release Def: + Release Def: Branch: $(BUILD_SOURCEBRANCH) Commit: $(BUILD_SOURCEVERSION) Build: https://aiinfra.visualstudio.com/Lotus/_build/results?buildId=$(BUILD_BUILDID) - true @@ -39,6 +39,7 @@ + + + + + + + + + + + + @@ -147,8 +211,8 @@ - + diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/NamedOnnxValue.cs b/csharp/src/Microsoft.ML.OnnxRuntime/NamedOnnxValue.cs index dfbd5e4899577..3ac360e67ebcd 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/NamedOnnxValue.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/NamedOnnxValue.cs @@ -4,7 +4,7 @@ using System; using System.Collections.Generic; using System.Text; -using System.Numerics.Tensors; +using Microsoft.ML.OnnxRuntime.Tensors; using System.Buffers; using System.Collections; using System.Diagnostics; @@ -162,6 +162,15 @@ out nativeElementType )) { } + else if (TryPinAsTensor(out pinnedMemoryHandle, + out dataBufferPointer, + out dataBufferLength, + out shape, + out rank, + out nativeElementType + )) + { + } else if (TryPinAsTensor(out pinnedMemoryHandle, out dataBufferPointer, out dataBufferLength, @@ -171,41 +180,93 @@ out nativeElementType )) { } - //TODO: add other types - else + // special case for string Tensor, data needs to be copied to the native buffer + else if (!(_value is Tensor)) { // nothing to cleanup here, since no memory has been pinned throw new NotSupportedException("The inference value " + nameof(_value) + " is not of a supported type"); } - Debug.Assert(dataBufferPointer != IntPtr.Zero, "dataBufferPointer must be non-null after obtaining the pinned buffer"); - - // copy to an ulong[] shape to match size_t[] - long[] longShape = new long[rank]; - for (int i = 0; i < rank; i++) + if (_value is Tensor) { - longShape[i] = shape[i]; - } + // calculate native tensor length (sum of string lengths in utf-8) + var tensorValue = _value as Tensor; + int totalLength = 0; + for (int i = 0; i < tensorValue.Length; i++) + { + totalLength += Encoding.UTF8.GetByteCount(tensorValue.GetValue(i)); + } - IntPtr status = NativeMethods.OrtCreateTensorWithDataAsOrtValue( - NativeMemoryAllocatorInfo.DefaultInstance.Handle, - dataBufferPointer, - (UIntPtr)(dataBufferLength), - longShape, - (UIntPtr)rank, - nativeElementType, - out onnxValue - ); - try - { - NativeApiStatus.VerifySuccess(status); + long[] longShape = new long[tensorValue.Dimensions.Length]; + for (int i = 0; i < tensorValue.Dimensions.Length; i++) + { + longShape[i] = tensorValue.Dimensions[i]; + } + + // allocate the native tensor + IntPtr nativeTensor = IntPtr.Zero; + try + { + NativeApiStatus.VerifySuccess(NativeMethods.OrtCreateTensorAsOrtValue( + NativeMemoryAllocator.DefaultInstance.Handle, + longShape, + (UIntPtr)(longShape.Length), + TensorElementType.String, + out nativeTensor + )); + + // fill the native tensor, using GetValue(index) from the Tensor + string[] stringsInTensor = new string[tensorValue.Length]; + for (int i = 0; i < tensorValue.Length; i++) + { + stringsInTensor[i] = tensorValue.GetValue(i); + } + NativeApiStatus.VerifySuccess(NativeMethods.OrtFillStringTensor(nativeTensor, stringsInTensor, (UIntPtr)tensorValue.Length)); + } + catch (OnnxRuntimeException e) + { + if (nativeTensor != IntPtr.Zero) + { + NativeMethods.OrtReleaseValue(nativeTensor); + throw e; + } + } + + onnxValue = nativeTensor; // set the output + pinnedMemoryHandle = default; // dummy value for the output } - catch (OnnxRuntimeException e) + else { - pinnedMemoryHandle.Dispose(); - throw e; + 
Debug.Assert(dataBufferPointer != IntPtr.Zero, "dataBufferPointer must be non-null after obtaining the pinned buffer"); + + // copy to an ulong[] shape to match size_t[] + long[] longShape = new long[rank]; + for (int i = 0; i < rank; i++) + { + longShape[i] = shape[i]; + } + + IntPtr status = NativeMethods.OrtCreateTensorWithDataAsOrtValue( + NativeMemoryAllocatorInfo.DefaultInstance.Handle, + dataBufferPointer, + (UIntPtr)(dataBufferLength), + longShape, + (UIntPtr)rank, + nativeElementType, + out onnxValue + ); + try + { + NativeApiStatus.VerifySuccess(status); + } + catch (OnnxRuntimeException e) + { + pinnedMemoryHandle.Dispose(); + throw e; + } + } } @@ -224,7 +285,9 @@ out TensorElementType nativeElementType dataBufferLength = 0; shape = null; rank = 0; - pinnedMemoryHandle = default(MemoryHandle); + pinnedMemoryHandle = default; + + Debug.Assert(typeof(T) != typeof(string), "NamedOnnxValue.TryPinAsTensor() must not be called with a string Tensor value"); if (_value is Tensor) { @@ -299,15 +362,21 @@ out TensorElementType nativeElementType nativeElementType = TensorElementType.UInt8; dataBufferLength = dt.Buffer.Length * sizeof(byte); } + else if (typeof(T) == typeof(sbyte)) + { + nativeElementType = TensorElementType.Int8; + dataBufferLength = dt.Buffer.Length * sizeof(sbyte); + } else if (typeof(T) == typeof(string)) { nativeElementType = TensorElementType.String; dataBufferLength = dt.Buffer.Length * IntPtr.Size; } - //TODO: Not supporting boolean for now. bool is non-blittable, the interop needs some care, and possibly need to copy - //else if (typeof(T) == typeof(bool)) - //{ - //} + else if (typeof(T) == typeof(bool)) + { + nativeElementType = TensorElementType.Bool; + dataBufferLength = dt.Buffer.Length * sizeof(bool); // Assumes sizeof(BOOL) is always 1 byte in native + } else { //TODO: may extend the supported types @@ -397,10 +466,18 @@ public static void GetTypeAndWidth(TensorElementType elemType, out Type type, ou type = typeof(byte); width = sizeof(byte); break; + case TensorElementType.Int8: + type = typeof(sbyte); + width = sizeof(sbyte); + break; case TensorElementType.String: type = typeof(byte); width = sizeof(byte); break; + case TensorElementType.Bool: + type = typeof(bool); + width = sizeof(bool); + break; default: type = null; width = 0; diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/NativeMemoryAllocator.cs b/csharp/src/Microsoft.ML.OnnxRuntime/NativeMemoryAllocator.cs index a9b4e60f5ac58..50b961bf23eec 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/NativeMemoryAllocator.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/NativeMemoryAllocator.cs @@ -77,22 +77,18 @@ protected override bool ReleaseHandle() internal class NativeMemoryAllocator : SafeHandle { - protected static readonly Lazy _defaultInstance = new Lazy(CreateDefaultCpuAllocator); + protected static readonly Lazy _defaultInstance = new Lazy(GetDefaultCpuAllocator); - private static NativeMemoryAllocator CreateDefaultCpuAllocator() + private static NativeMemoryAllocator GetDefaultCpuAllocator() { IntPtr allocator = IntPtr.Zero; try { - IntPtr status = NativeMethods.OrtCreateDefaultAllocator(out allocator); + IntPtr status = NativeMethods.OrtGetAllocatorWithDefaultOptions(out allocator); NativeApiStatus.VerifySuccess(status); } catch (Exception e) { - if (allocator != IntPtr.Zero) - { - Delete(allocator); - } throw e; } @@ -124,7 +120,7 @@ public override bool IsInvalid } } - internal IntPtr Handle + internal IntPtr Handle { get { @@ -138,15 +134,8 @@ protected NativeMemoryAllocator(IntPtr allocator) 
this.handle = allocator; } - - protected static void Delete(IntPtr allocator) - { - NativeMethods.OrtReleaseAllocator(allocator); - } - protected override bool ReleaseHandle() { - Delete(this.handle); return true; } } diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.cs b/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.cs index 4c213ec66d58e..03650989e4604 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.cs @@ -130,6 +130,9 @@ IntPtr[] outputValues /* An array of output value pointers. Array must be alloca [DllImport(nativeLib, CharSet = charSet)] public static extern IntPtr /*(OrtStatus*)*/ OrtDisableSequentialExecution(IntPtr /*(OrtSessionOptions*)*/ options); + [DllImport(nativeLib, CharSet = charSet)] + public static extern IntPtr /*(OrtStatus*)*/ OrtSetOptimizedModelFilePath(IntPtr /* OrtSessionOptions* */ options, [MarshalAs(UnmanagedType.LPWStr)]string optimizedModelFilepath); + [DllImport(nativeLib, CharSet = charSet)] public static extern IntPtr /*(OrtStatus*)*/ OrtEnableProfiling(IntPtr /* OrtSessionOptions* */ options, string profilePathPrefix); @@ -154,11 +157,14 @@ IntPtr[] outputValues /* An array of output value pointers. Array must be alloca [DllImport(nativeLib, CharSet = charSet)] public static extern IntPtr /*(OrtStatus*)*/ OrtSetSessionLogVerbosityLevel(IntPtr /* OrtSessionOptions* */ options, LogLevel sessionLogVerbosityLevel); + [DllImport(nativeLib, CharSet = charSet)] + public static extern IntPtr /*(OrtStatus*)*/ OrtSetSessionLogSeverityLevel(IntPtr /* OrtSessionOptions* */ options, LogLevel sessionLogSeverityLevel); + [DllImport(nativeLib, CharSet = charSet)] public static extern IntPtr /*(OrtStatus*)*/ OrtSetSessionThreadPoolSize(IntPtr /* OrtSessionOptions* */ options, int sessionThreadPoolSize); [DllImport(nativeLib, CharSet = charSet)] - public static extern IntPtr /*(OrtStatus*)*/ OrtSetSessionGraphOptimizationLevel(IntPtr /* OrtSessionOptions* */ options, uint graphOptimizationLevel); + public static extern IntPtr /*(OrtStatus*)*/ OrtSetSessionGraphOptimizationLevel(IntPtr /* OrtSessionOptions* */ options, GraphOptimizationLevel graphOptimizationLevel); ///** @@ -175,12 +181,43 @@ IntPtr[] outputValues /* An array of output value pointers. 
Array must be alloca [DllImport(nativeLib, CharSet = charSet)] public static extern IntPtr /*(OrtStatus*)*/ OrtSessionOptionsAppendExecutionProvider_CUDA(IntPtr /*(OrtSessionOptions*) */ options, int device_id); - //[DllImport(nativeLib, CharSet = charSet)] - //public static extern IntPtr /*(OrtStatus*)*/ OrtCreateNupharExecutionProviderFactory(int device_id, string target_str, out IntPtr /*(OrtProviderFactoryPtr**)*/ factory); + [DllImport(nativeLib, CharSet = charSet)] + public static extern IntPtr /*(OrtStatus*)*/ OrtSessionOptionsAppendExecutionProvider_Nuphar(IntPtr /*(OrtSessionOptions*) */ options, int allow_unaligned_buffers, string settings); //[DllImport(nativeLib, CharSet = charSet)] //public static extern void OrtAddCustomOp(IntPtr /*(OrtSessionOptions*)*/ options, string custom_op_path); + #endregion + + #region RunOptions API + [DllImport(nativeLib, CharSet = charSet)] + public static extern IntPtr /*(OrtStatus*)*/ OrtCreateRunOptions( out IntPtr /* OrtRunOptions** */ runOptions); + + [DllImport(nativeLib, CharSet = charSet)] + public static extern void OrtReleaseRunOptions(IntPtr /*(OrtRunOptions*)*/options); + + [DllImport(nativeLib, CharSet = charSet)] + public static extern IntPtr /*(OrtStatus*)*/ OrtRunOptionsSetRunLogVerbosityLevel(IntPtr /* OrtRunOptions* */ options, LogLevel value); + + [DllImport(nativeLib, CharSet = charSet)] + public static extern IntPtr /*(OrtStatus*)*/ OrtRunOptionsSetRunTag(IntPtr /* OrtRunOptions* */ options, string /* const char* */ runTag); + + [DllImport(nativeLib, CharSet = charSet)] + public static extern IntPtr /*(OrtStatus*)*/ OrtRunOptionsGetRunLogVerbosityLevel(IntPtr /* OrtRunOptions* */ options, out LogLevel verbosityLevel); + + [DllImport(nativeLib, CharSet = charSet)] + public static extern IntPtr /*(OrtStatus*)*/ OrtRunOptionsGetRunTag(IntPtr /* const OrtRunOptions* */options, out IntPtr /* const char** */ runtag); + + // Set a flag so that any running OrtRun* calls that are using this instance of OrtRunOptions + // will exit as soon as possible if the flag is true. 
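A sketch of how the terminate flag exposed here might be used through the managed RunOptions wrapper added later in this change; the worker-thread pattern and names are illustrative, and a terminated run is assumed to surface an error rather than return results.

using System;
using System.Collections.Generic;
using System.Threading.Tasks;
using Microsoft.ML.OnnxRuntime;

static class TerminateSketch
{
    // Starts a run on a worker thread, then requests early exit through the shared RunOptions.
    public static void RunThenCancel(InferenceSession session,
                                     IReadOnlyCollection<NamedOnnxValue> inputs,
                                     string[] outputNames)
    {
        using (var runOptions = new RunOptions())
        {
            var work = Task.Run(() =>
            {
                using (session.Run(inputs, outputNames, runOptions)) { }
            });

            runOptions.Terminate = true;   // any in-flight Run using this instance is asked to stop

            try { work.Wait(); }
            catch (AggregateException) { /* a terminated run is expected to fail with an error status */ }
        }
    }
}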
+ [DllImport(nativeLib, CharSet = charSet)] + public static extern IntPtr /*(OrtStatus*)*/ OrtRunOptionsEnableTerminate(IntPtr /* OrtRunOptions* */ options); + + [DllImport(nativeLib, CharSet = charSet)] + public static extern IntPtr /*(OrtStatus*)*/ OrtRunOptionsDisableTerminate(IntPtr /* OrtRunOptions* */ options); + + + #endregion #region Allocator/AllocatorInfo API @@ -223,10 +260,7 @@ public enum MemoryType public static extern void OrtReleaseAllocatorInfo(IntPtr /*(OrtAllocatorInfo*)*/ allocatorInfo); [DllImport(nativeLib, CharSet = charSet)] - public static extern IntPtr /*(OrtStatus*)*/OrtCreateDefaultAllocator(out IntPtr /*(OrtAllocator**)*/ allocator); - - [DllImport(nativeLib, CharSet = charSet)] - public static extern void OrtReleaseAllocator(IntPtr /*(OrtAllocator*)*/ allocator); + public static extern IntPtr /*(OrtStatus*)*/OrtGetAllocatorWithDefaultOptions(out IntPtr /*(OrtAllocator**)*/ allocator); /// /// Release any object allocated by an allocator @@ -261,6 +295,14 @@ public enum MemoryType [DllImport(nativeLib, CharSet = charSet)] public static extern IntPtr /*(OrtStatus*)*/ OrtGetTypeInfo(IntPtr /*(OrtValue*)*/ value, IntPtr /*(OrtValue**)*/ typeInfo); + [DllImport(nativeLib, CharSet = charSet)] + public static extern IntPtr /*(OrtStatus*)*/ OrtCreateTensorAsOrtValue( + IntPtr /*_Inout_ OrtAllocator* */ allocator, + long[] /*_In_ const int64_t* */ shape, + UIntPtr /*size_t*/ shape_len, + TensorElementType type, + out IntPtr /* OrtValue** */ outputValue); + [DllImport(nativeLib, CharSet = charSet)] public static extern IntPtr /* OrtStatus */ OrtCreateTensorWithDataAsOrtValue( IntPtr /* (const OrtAllocatorInfo*) */ allocatorInfo, @@ -276,6 +318,15 @@ public enum MemoryType [DllImport(nativeLib, CharSet = charSet)] public static extern IntPtr /*(OrtStatus*)*/ OrtGetTensorMutableData(IntPtr /*(OrtValue*)*/ value, out IntPtr /* (void**)*/ dataBufferHandle); + + /// \param value A tensor created from OrtCreateTensor... function. + /// \param len total data length, not including the trailing '\0' chars. 
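A short sketch of feeding a string tensor through the new path, assuming the DenseTensor indexer from the bundled Tensors namespace and a hypothetical input name; on Run() the managed wrapper now allocates a native string tensor and copies each element as UTF-8 instead of pinning managed memory.

using Microsoft.ML.OnnxRuntime;
using Microsoft.ML.OnnxRuntime.Tensors;

static class StringTensorSketch
{
    // Builds a 1x2 string tensor input; the wrapper routes it through
    // OrtCreateTensorAsOrtValue + OrtFillStringTensor when the session runs.
    public static NamedOnnxValue MakeStringInput()
    {
        var labels = new DenseTensor<string>(new[] { 1, 2 });
        labels[0, 0] = "cat";
        labels[0, 1] = "dog";
        return NamedOnnxValue.CreateFromTensor("labels", labels);   // "labels" is a placeholder input name
    }
}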
+ [DllImport(nativeLib, CharSet = charSet)] + public static extern IntPtr /*(OrtStatus*)*/ OrtFillStringTensor( + IntPtr /* OrtValue */ value, + string[] /* const char* const* */s, + UIntPtr /* size_t */ s_len); + [DllImport(nativeLib, CharSet = charSet)] public static extern IntPtr /*(OrtStatus*)*/ OrtGetStringTensorContent( IntPtr /*(OrtValue*)*/ value, diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/OnnxRuntime.cs b/csharp/src/Microsoft.ML.OnnxRuntime/OnnxRuntime.cs index 62141e9a81362..3069b4cb71b4c 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/OnnxRuntime.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/OnnxRuntime.cs @@ -15,7 +15,7 @@ internal struct GlobalOptions //Options are currently not accessible to user public LogLevel LogLevel { get; set; } } - internal enum LogLevel + public enum LogLevel { Verbose = 0, Info = 1, @@ -51,6 +51,9 @@ public override bool IsInvalid private OnnxRuntime() //Problem: it is not possible to pass any option for a Singleton :base(IntPtr.Zero, true) { + // Check LibC version on Linux, before doing any onnxruntime initialization + CheckLibcVersionGreaterThanMinimum(); + handle = IntPtr.Zero; try { @@ -78,5 +81,32 @@ protected override bool ReleaseHandle() Delete(handle); return true; } + + [DllImport("libc", ExactSpelling = true, CallingConvention = CallingConvention.Cdecl)] + private static extern IntPtr gnu_get_libc_version(); + + private static void CheckLibcVersionGreaterThanMinimum() + { + // require libc version 2.23 or higher + var minVersion = new Version(2, 23); + var curVersion = new Version(0, 0); + if (RuntimeInformation.IsOSPlatform(OSPlatform.Linux)) + { + try + { + curVersion = Version.Parse(Marshal.PtrToStringAnsi(gnu_get_libc_version())); + if (curVersion >= minVersion) + return; + } + catch (Exception) + { + // trap any obscure exception + } + throw new OnnxRuntimeException(ErrorCode.RuntimeException, + $"libc.so version={curVersion} does not meet the minimun of 2.23 required by OnnxRuntime. " + + "Linux distribution should be similar to Ubuntu 16.04 or higher"); + } + } + } } \ No newline at end of file diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/RunOptions.cs b/csharp/src/Microsoft.ML.OnnxRuntime/RunOptions.cs new file mode 100644 index 0000000000000..b40c795757397 --- /dev/null +++ b/csharp/src/Microsoft.ML.OnnxRuntime/RunOptions.cs @@ -0,0 +1,120 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +using System; +using System.Runtime.InteropServices; + +namespace Microsoft.ML.OnnxRuntime +{ + /// Sets various runtime options. + public class RunOptions: IDisposable + { + private IntPtr _nativePtr; + internal IntPtr Handle + { + get + { + return _nativePtr; + } + } + + + public RunOptions() + { + NativeApiStatus.VerifySuccess(NativeMethods.OrtCreateRunOptions(out _nativePtr)); + } + + + /// + /// LogVerbosityLevel for the Run + /// default == LogLevel.Verbose + /// + public LogLevel LogVerbosityLevel + { + get + { + LogLevel level; + NativeApiStatus.VerifySuccess(NativeMethods.OrtRunOptionsGetRunLogVerbosityLevel(_nativePtr, out level)); + return level; + } + set + { + NativeApiStatus.VerifySuccess(NativeMethods.OrtRunOptionsSetRunLogVerbosityLevel(_nativePtr, value)); + } + } + + + /// + /// Log tag to be used during the run. 
default = "" + /// + public string LogTag + { + get + { + string tag = null; + IntPtr tagPtr = IntPtr.Zero; + NativeApiStatus.VerifySuccess(NativeMethods.OrtRunOptionsGetRunTag(_nativePtr, out tagPtr)); + tag = Marshal.PtrToStringAnsi(tagPtr); // assume ANSI string + // should not release the memory of the tagPtr, because it returns the c_str() of the std::string being used inside RunOptions C++ class + return tag; + } + set + { + NativeApiStatus.VerifySuccess(NativeMethods.OrtRunOptionsSetRunTag(_nativePtr, value)); + } + } + + + /// + /// Sets a flag to terminate any other Run() call that is using the same RunOptions object + /// Default = false + /// + public bool Terminate + { + get + { + return _terminate; + } + set + { + if (!_terminate && value) + { + NativeApiStatus.VerifySuccess(NativeMethods.OrtRunOptionsEnableTerminate(_nativePtr)); + _terminate = true; + } + else if (_terminate && !value) + { + NativeApiStatus.VerifySuccess(NativeMethods.OrtRunOptionsDisableTerminate(_nativePtr)); + _terminate = false; + } + } + } + private bool _terminate = false; //value set to default value of the C++ RunOptions + + + #region destructors disposers + + ~RunOptions() + { + Dispose(false); + } + + + public void Dispose() + { + GC.SuppressFinalize(this); + Dispose(true); + } + + + protected virtual void Dispose(bool disposing) + { + if (disposing) + { + // cleanup managed resources + } + NativeMethods.OrtReleaseRunOptions(_nativePtr); + } + + #endregion + } +} \ No newline at end of file diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/SessionOptions.cs b/csharp/src/Microsoft.ML.OnnxRuntime/SessionOptions.cs index 4ce708687ef79..56c597f9e8411 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/SessionOptions.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/SessionOptions.cs @@ -4,117 +4,304 @@ using System; using System.Text; using System.Runtime.InteropServices; +using System.IO; namespace Microsoft.ML.OnnxRuntime { + /// + /// TODO Add documentation about which optimizations are enabled for each value. + /// + public enum GraphOptimizationLevel + { + ORT_DISABLE_ALL = 0, + ORT_ENABLE_BASIC = 1, + ORT_ENABLE_EXTENDED = 2, + ORT_ENABLE_ALL = 99 + } + /// /// Holds the options for creating an InferenceSession /// - public class SessionOptions:IDisposable + public class SessionOptions : IDisposable { - public IntPtr _nativePtr; - protected static readonly Lazy _default = new Lazy(MakeSessionOptionWithCpuProvider); + private IntPtr _nativePtr; private static string[] cudaDelayLoadedLibs = { "cublas64_100.dll", "cudnn64_7.dll" }; + #region Constructor and Factory methods + /// /// Constructs an empty SessionOptions /// public SessionOptions() { - NativeMethods.OrtCreateSessionOptions(out _nativePtr); + NativeApiStatus.VerifySuccess(NativeMethods.OrtCreateSessionOptions(out _nativePtr)); } + /// - /// Sets the graph optimization level for the session. Default is set to 1. 
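For reference, a minimal sketch of how the RunOptions wrapper introduced above might be used from application code (illustrative only, assuming the Microsoft.ML.OnnxRuntime namespace; each setter forwards to the matching OrtRunOptions* native call shown earlier in this patch):

    using Microsoft.ML.OnnxRuntime;

    // Per-Run options; disposing releases the native OrtRunOptions handle.
    using (var runOptions = new RunOptions())
    {
        runOptions.LogTag = "my-run";                  // OrtRunOptionsSetRunTag
        runOptions.LogVerbosityLevel = LogLevel.Info;  // OrtRunOptionsSetRunLogVerbosityLevel
        // Flipping Terminate to true from another thread asks any in-flight Run() call
        // sharing this RunOptions instance to stop early (OrtRunOptionsEnableTerminate).
    }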
+ /// A helper method to constuct a SessionOptions object for CUDA execution /// - /// optimization level for the session - /// Available options are : 0, 1, 2 - /// 0 -> Disable all optimizations - /// 1 -> Enable basic optimizations - /// 2 -> Enable all optimizations - public void SetSessionGraphOptimizationLevel(uint optimization_level) + /// A SessionsOptions() object configured for execution on deviceId=0 + public static SessionOptions MakeSessionOptionWithCudaProvider() { - NativeApiStatus.VerifySuccess(NativeMethods.OrtSetSessionGraphOptimizationLevel(_nativePtr, optimization_level)); + return MakeSessionOptionWithCudaProvider(0); } + /// - /// Enable Sequential Execution. By default, it is enabled. + /// A helper method to constuct a SessionOptions object for CUDA execution /// - /// - public void EnableSequentialExecution() + /// + /// A SessionsOptions() object configured for execution on deviceId + public static SessionOptions MakeSessionOptionWithCudaProvider(int deviceId = 0) { - NativeApiStatus.VerifySuccess(NativeMethods.OrtEnableSequentialExecution(_nativePtr)); + CheckCudaExecutionProviderDLLs(); + SessionOptions options = new SessionOptions(); + NativeMethods.OrtSessionOptionsAppendExecutionProvider_CUDA(options._nativePtr, deviceId); + NativeMethods.OrtSessionOptionsAppendExecutionProvider_CPU(options._nativePtr, 1); + return options; } /// - /// Disable Sequential Execution and enable Parallel Execution. + /// A helper method to construct a SessionOptions object for Nuphar execution /// - /// - public void DisableSequentialExecution() + /// settings string, comprises of comma separated key:value pairs. default is empty + /// A SessionsOptions() object configured for execution with Nuphar + public static SessionOptions MakeSessionOptionWithNupharProvider(String settings = "") + { + SessionOptions options = new SessionOptions(); + NativeMethods.OrtSessionOptionsAppendExecutionProvider_Nuphar(options._nativePtr, 1, settings); + return options; + } + + #endregion + + #region Public Properties + + internal IntPtr Handle { - NativeApiStatus.VerifySuccess(NativeMethods.OrtDisableSequentialExecution(_nativePtr)); + get + { + return _nativePtr; + } } + /// - /// Enable Mem Pattern. By default, it is enabled + /// Enable Sequential Execution. Default = true. /// /// - public void EnableMemPattern() + /// + public bool EnableSequentialExecution { - NativeApiStatus.VerifySuccess(NativeMethods.OrtEnableMemPattern(_nativePtr)); + get + { + return _enableSequentialExecution; + } + set + { + if (!_enableSequentialExecution && value) + { + NativeApiStatus.VerifySuccess(NativeMethods.OrtEnableSequentialExecution(_nativePtr)); + _enableSequentialExecution = true; + } + else if (_enableSequentialExecution && !value) + { + NativeApiStatus.VerifySuccess(NativeMethods.OrtDisableSequentialExecution(_nativePtr)); + _enableSequentialExecution = false; + } + } } + private bool _enableSequentialExecution = true; + /// - /// Disable Mem Pattern. + /// Enables the use of the memory allocation patterns in the first Run() call for subsequent runs. Default = true. 
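A sketch of how the factory helpers above are intended to be called (illustrative; device id 0 and the empty Nuphar settings string are placeholders, not recommendations):

    using Microsoft.ML.OnnxRuntime;

    // GPU session: appends the CUDA execution provider on device 0, with CPU as fallback.
    SessionOptions cudaOptions = SessionOptions.MakeSessionOptionWithCudaProvider(deviceId: 0);

    // Nuphar session: the settings string is a comma-separated list of key:value pairs.
    SessionOptions nupharOptions = SessionOptions.MakeSessionOptionWithNupharProvider("");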
/// - /// - public void DisableMemPattern() + public bool EnableMemoryPattern { - NativeApiStatus.VerifySuccess(NativeMethods.OrtDisableMemPattern(_nativePtr)); + get + { + return _enableMemoryPattern; + } + set + { + if (!_enableMemoryPattern && value) + { + NativeApiStatus.VerifySuccess(NativeMethods.OrtEnableMemPattern(_nativePtr)); + _enableMemoryPattern = true; + } + else if (_enableMemoryPattern && !value) + { + NativeApiStatus.VerifySuccess(NativeMethods.OrtDisableMemPattern(_nativePtr)); + _enableMemoryPattern = false; + } + } } + private bool _enableMemoryPattern = true; + /// - /// Default instance + /// Path prefix to use for output of profiling data /// - public static SessionOptions Default + public string ProfileOutputPathPrefix + { + get; set; + } = "onnxruntime_profile_"; // this is the same default in C++ implementation + + + + /// + /// Enables profiling of InferenceSession.Run() calls. Default is false + /// + public bool EnableProfiling { get { - return _default.Value; + return _enableProfiling; + } + set + { + if (!_enableProfiling && value) + { + NativeApiStatus.VerifySuccess(NativeMethods.OrtEnableProfiling(_nativePtr, ProfileOutputPathPrefix)); + _enableProfiling = true; + } + else if (_enableProfiling && !value) + { + NativeApiStatus.VerifySuccess(NativeMethods.OrtDisableProfiling(_nativePtr)); + _enableProfiling = false; + } } } + private bool _enableProfiling = false; - private static SessionOptions MakeSessionOptionWithCpuProvider() + /// + /// Set filepath to save optimized model after graph level transformations. Default is empty, which implies saving is disabled. + /// + public string OptimizedModelFilePath { - CheckLibcVersionGreaterThanMinimum(); - SessionOptions options = new SessionOptions(); - NativeMethods.OrtSessionOptionsAppendExecutionProvider_CPU(options._nativePtr, 1); - return options; + get + { + return _optimizedModelFilePath; + } + set + { + if (value != _optimizedModelFilePath) + { + NativeApiStatus.VerifySuccess(NativeMethods.OrtSetOptimizedModelFilePath(_nativePtr, value)); + _optimizedModelFilePath = value; + } + } } + private string _optimizedModelFilePath = ""; + + /// - /// A helper method to constuct a SessionOptions object for CUDA execution + /// Enables Arena allocator for the CPU memory allocations. Default is true. /// - /// A SessionsOptions() object configured for execution on deviceId=0 - public static SessionOptions MakeSessionOptionWithCudaProvider() + public bool EnableCpuMemArena { - return MakeSessionOptionWithCudaProvider(0); + get + { + return _enableCpuMemArena; + } + set + { + if (!_enableCpuMemArena && value) + { + NativeApiStatus.VerifySuccess(NativeMethods.OrtEnableCpuMemArena(_nativePtr)); + _enableCpuMemArena = true; + } + else if (_enableCpuMemArena && !value) + { + NativeApiStatus.VerifySuccess(NativeMethods.OrtDisableCpuMemArena(_nativePtr)); + _enableCpuMemArena = false; + } + } } + private bool _enableCpuMemArena = true; + /// - /// A helper method to constuct a SessionOptions object for CUDA execution + /// Log Id to be used for the session. Default is empty string. + /// TODO: Should it be named LogTag as in RunOptions? 
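With the setter/getter methods replaced by properties, configuration now reads as below (a sketch; the profile prefix and output path are placeholder values):

    using Microsoft.ML.OnnxRuntime;

    var options = new SessionOptions
    {
        EnableSequentialExecution = true,
        EnableMemoryPattern = true,
        EnableCpuMemArena = true,
        ProfileOutputPathPrefix = "my_profile_",          // read when EnableProfiling is switched on
        EnableProfiling = true,
        OptimizedModelFilePath = "optimized_model.onnx",  // empty string disables saving
        GraphOptimizationLevel = GraphOptimizationLevel.ORT_ENABLE_EXTENDED
    };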
/// - /// - /// A SessionsOptions() object configured for execution on deviceId - public static SessionOptions MakeSessionOptionWithCudaProvider(int deviceId=0) + public string LogId { - CheckLibcVersionGreaterThanMinimum(); - CheckCudaExecutionProviderDLLs(); - SessionOptions options = new SessionOptions(); - NativeMethods.OrtSessionOptionsAppendExecutionProvider_CUDA(options._nativePtr, deviceId); - NativeMethods.OrtSessionOptionsAppendExecutionProvider_CPU(options._nativePtr, 1); - return options; + get + { + return _logId; + } + + set + { + NativeApiStatus.VerifySuccess(NativeMethods.OrtSetSessionLogId(_nativePtr, value)); + _logId = value; + } + } + private string _logId = ""; + + + /// + /// Log Verbosity Level for the session logs. Default = LogLevel.Verbose + /// + public LogLevel LogVerbosityLevel + { + get + { + return _logVerbosityLevel; + } + set + { + NativeApiStatus.VerifySuccess(NativeMethods.OrtSetSessionLogVerbosityLevel(_nativePtr, value)); + _logVerbosityLevel = value; + } + } + private LogLevel _logVerbosityLevel = LogLevel.Verbose; + + + /// + /// Threadpool size for the session.Run() calls. + /// Default = 0, meaning threadpool size is aumatically selected from number of available cores. + /// + public int ThreadPoolSize + { + get + { + return _threadPoolSize; + } + set + { + NativeApiStatus.VerifySuccess(NativeMethods.OrtSetSessionThreadPoolSize(_nativePtr, value)); + _threadPoolSize = value; + } } + private int _threadPoolSize = 0; // set to what is set in C++ SessionOptions by default; + + + /// + /// Sets the graph optimization level for the session. Default is set to ORT_ENABLE_BASIC. + /// + public GraphOptimizationLevel GraphOptimizationLevel + { + get + { + return _graphOptimizationLevel; + } + set + { + NativeApiStatus.VerifySuccess(NativeMethods.OrtSetSessionGraphOptimizationLevel(_nativePtr, value)); + _graphOptimizationLevel = value; + } + } + private GraphOptimizationLevel _graphOptimizationLevel = GraphOptimizationLevel.ORT_ENABLE_BASIC; + + #endregion + + #region Private Methods + // Declared, but called only if OS = Windows. [DllImport("kernel32.dll")] @@ -130,45 +317,21 @@ private static bool CheckCudaExecutionProviderDLLs() { IntPtr handle = LoadLibrary(dll); if (handle != IntPtr.Zero) - continue; + continue; var sysdir = new StringBuilder(String.Empty, 2048); GetSystemDirectory(sysdir, (uint)sysdir.Capacity); throw new OnnxRuntimeException( - ErrorCode.NoSuchFile, + ErrorCode.NoSuchFile, $"kernel32.LoadLibrary():'{dll}' not found. CUDA is required for GPU execution. " + $". Verify it is available in the system directory={sysdir}. Else copy it to the output folder." - ); + ); } - } + } return true; } - [DllImport("libc", ExactSpelling = true, CallingConvention = CallingConvention.Cdecl)] - private static extern IntPtr gnu_get_libc_version(); - - private static void CheckLibcVersionGreaterThanMinimum() - { - // require libc version 2.23 or higher - var minVersion = new Version(2, 23); - var curVersion = new Version(0, 0); - if (RuntimeInformation.IsOSPlatform(OSPlatform.Linux)) - { - try - { - curVersion = Version.Parse(Marshal.PtrToStringAnsi(gnu_get_libc_version())); - if (curVersion >= minVersion) - return; - } - catch (Exception) - { - // trap any obscure exception - } - throw new OnnxRuntimeException(ErrorCode.RuntimeException, - $"libc.so version={curVersion} does not meet the minimun of 2.23 required by OnnxRuntime. 
" + - "Linux distribution should be similar to Ubuntu 16.04 or higher"); - } - } + #endregion #region destructors disposers ~SessionOptions() diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/Tensors/ArrayTensorExtensions.cs b/csharp/src/Microsoft.ML.OnnxRuntime/Tensors/ArrayTensorExtensions.cs new file mode 100644 index 0000000000000..5189ddf71e300 --- /dev/null +++ b/csharp/src/Microsoft.ML.OnnxRuntime/Tensors/ArrayTensorExtensions.cs @@ -0,0 +1,66 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// This file is copied and adapted from the following git repository - +// https://github.com/dotnet/corefx +// Commit ID: bdd0814360d4c3a58860919f292a306242f27da1 +// Path: /src/System.Numerics.Tensors/src/System/Numerics/Tensors/ArrayTensorExtensions.cs +// Original license statement below - + +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. +using System; + +namespace Microsoft.ML.OnnxRuntime.Tensors +{ + public static class ArrayTensorExtensions + { + /// + /// Creates a copy of this single-dimensional array as a DenseTensor<T> + /// + /// Type contained in the array to copy to the DenseTensor<T>. + /// The array to create a DenseTensor<T> from. + /// A 1-dimensional DenseTensor<T> with the same length and content as . + public static DenseTensor ToTensor(this T[] array) + { + return new DenseTensor(array); + } + + /// + /// Creates a copy of this two-dimensional array as a DenseTensor<T> + /// + /// Type contained in the array to copy to the DenseTensor<T>. + /// The array to create a DenseTensor<T> from. + /// False (default) to indicate that the first dimension is most major (farthest apart) and the last dimension is most minor (closest together): row-major. True to indicate that the last dimension is most major (farthest apart) and the first dimension is most minor (closest together): column-major. + /// A 2-dimensional DenseTensor<T> with the same dimensions and content as . + public static DenseTensor ToTensor(this T[,] array, bool reverseStride = false) + { + return new DenseTensor(array, reverseStride); + } + + /// + /// Creates a copy of this three-dimensional array as a DenseTensor<T> + /// + /// Type contained in the array to copy to the DenseTensor<T>. + /// The array to create a DenseTensor<T> from. + /// False (default) to indicate that the first dimension is most major (farthest apart) and the last dimension is most minor (closest together): akin to row-major in a rank-2 tensor. True to indicate that the last dimension is most major (farthest apart) and the first dimension is most minor (closest together): akin to column-major in a rank-2 tensor. + /// A 3-dimensional DenseTensor<T> with the same dimensions and content as . + public static DenseTensor ToTensor(this T[,,] array, bool reverseStride = false) + { + return new DenseTensor(array, reverseStride); + } + + /// + /// Creates a copy of this n-dimensional array as a DenseTensor<T> + /// + /// Type contained in the array to copy to the DenseTensor<T>. + /// The array to create a DenseTensor<T> from. + /// False (default) to indicate that the first dimension is most major (farthest apart) and the last dimension is most minor (closest together): akin to row-major in a rank-2 tensor. 
True to indicate that the last dimension is most major (farthest apart) and the first dimension is most minor (closest together): akin to column-major in a rank-2 tensor. + /// A n-dimensional DenseTensor<T> with the same dimensions and content as . + public static DenseTensor ToTensor(this Array array, bool reverseStride = false) + { + return new DenseTensor(array, reverseStride); + } + } +} diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/Tensors/ArrayUtilities.cs b/csharp/src/Microsoft.ML.OnnxRuntime/Tensors/ArrayUtilities.cs new file mode 100644 index 0000000000000..2913799968930 --- /dev/null +++ b/csharp/src/Microsoft.ML.OnnxRuntime/Tensors/ArrayUtilities.cs @@ -0,0 +1,227 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// This file is copied and adapted from the following git repository - +// https://github.com/dotnet/corefx +// Commit ID: bdd0814360d4c3a58860919f292a306242f27da1 +// Path: /src/System.Numerics.Tensors/src/System/Numerics/Tensors/ArrayUtilities.cs +// Original license statement below - + +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System.Diagnostics; +using System; + +namespace Microsoft.ML.OnnxRuntime.Tensors +{ + internal static class ArrayUtilities + { + public const int StackallocMax = 16; + + public static long GetProduct(ReadOnlySpan dimensions, int startIndex = 0) + { + if (dimensions.Length == 0) + { + return 0; + } + + long product = 1; + for (int i = startIndex; i < dimensions.Length; i++) + { + if (dimensions[i] < 0) + { + throw new ArgumentOutOfRangeException($"{nameof(dimensions)}[{i}]"); + } + + // we use a long which should be much larger than is ever used here, + // but still force checked + checked + { + product *= dimensions[i]; + } + } + + return product; + } + + public static bool IsAscending(ReadOnlySpan values) + { + for (int i = 1; i < values.Length; i++) + { + if (values[i] < values[i - 1]) + { + return false; + } + } + + return true; + } + + public static bool IsDescending(ReadOnlySpan values) + { + for (int i = 1; i < values.Length; i++) + { + if (values[i] > values[i - 1]) + { + return false; + } + } + + return true; + } + + /// + /// Gets the set of strides that can be used to calculate the offset of n-dimensions in a 1-dimensional layout + /// + /// + /// + /// + public static int[] GetStrides(ReadOnlySpan dimensions, bool reverseStride = false) + { + int[] strides = new int[dimensions.Length]; + + int stride = 1; + if (reverseStride) + { + for (int i = 0; i < strides.Length; i++) + { + strides[i] = stride; + stride *= dimensions[i]; + } + } + else + { + for (int i = strides.Length - 1; i >= 0; i--) + { + strides[i] = stride; + stride *= dimensions[i]; + } + } + + return strides; + } + + public static void SplitStrides(int[] strides, int[] splitAxes, int[] newStrides, int stridesOffset, int[] splitStrides, int splitStridesOffset) + { + int newStrideIndex = 0; + for (int i = 0; i < strides.Length; i++) + { + int stride = strides[i]; + bool isSplit = false; + for (int j = 0; j < splitAxes.Length; j++) + { + if (splitAxes[j] == i) + { + splitStrides[splitStridesOffset + j] = stride; + isSplit = true; + break; + } + } + + if (!isSplit) + { + newStrides[stridesOffset + newStrideIndex++] = stride; + } + } + } + + /// + /// Calculates the 1-d index for n-d indices in layout specified by strides. 
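To make the stride convention concrete, here is a small worked example of the row-major case produced by GetStrides above and consumed by GetIndex, whose implementation follows. ArrayUtilities itself is internal, so the sketch goes through the public tensor API added later in this patch:

    using Microsoft.ML.OnnxRuntime.Tensors;

    // Dimensions {2, 3, 4} with reverseStride = false (row-major) give strides {12, 4, 1}:
    // the last dimension is contiguous, and each earlier stride is the product of the later dimensions.
    var t = new DenseTensor<int>(new[] { 2, 3, 4 });

    // Linearized index of element (1, 2, 3) = 1*12 + 2*4 + 3*1 = 23.
    int linear = 1 * t.Strides[0] + 2 * t.Strides[1] + 3 * t.Strides[2];
    t.SetValue(linear, 42);
    System.Diagnostics.Debug.Assert(t[1, 2, 3] == 42);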
+ /// + /// + /// + /// + /// + public static int GetIndex(int[] strides, ReadOnlySpan indices, int startFromDimension = 0) + { + Debug.Assert(strides.Length == indices.Length); + + int index = 0; + for (int i = startFromDimension; i < indices.Length; i++) + { + index += strides[i] * indices[i]; + } + + return index; + } + + /// + /// Calculates the n-d indices from the 1-d index in a layout specificed by strides + /// + /// + /// + /// + /// + /// + public static void GetIndices(ReadOnlySpan strides, bool reverseStride, int index, int[] indices, int startFromDimension = 0) + { + Debug.Assert(reverseStride ? IsAscending(strides) : IsDescending(strides), "Index decomposition requires ordered strides"); + Debug.Assert(strides.Length == indices.Length); + + int remainder = index; + for (int i = startFromDimension; i < strides.Length; i++) + { + // reverse the index for reverseStride so that we divide by largest stride first + var nIndex = reverseStride ? strides.Length - 1 - i : i; + + var stride = strides[nIndex]; + indices[nIndex] = remainder / stride; + remainder %= stride; + } + } + + /// + /// Calculates the n-d indices from the 1-d index in a layout specificed by strides + /// + /// + /// + /// + /// + /// + public static void GetIndices(ReadOnlySpan strides, bool reverseStride, int index, Span indices, int startFromDimension = 0) + { + Debug.Assert(reverseStride ? IsAscending(strides) : IsDescending(strides), "Index decomposition requires ordered strides"); + Debug.Assert(strides.Length == indices.Length); + + int remainder = index; + for (int i = startFromDimension; i < strides.Length; i++) + { + // reverse the index for reverseStride so that we divide by largest stride first + var nIndex = reverseStride ? strides.Length - 1 - i : i; + + var stride = strides[nIndex]; + indices[nIndex] = remainder / stride; + remainder %= stride; + } + } + + /// + /// Takes an 1-d index over n-d sourceStrides and recalculates it assuming same n-d coordinates over a different n-d strides + /// + public static int TransformIndexByStrides(int index, int[] sourceStrides, bool sourceReverseStride, int[] transformStrides) + { + Debug.Assert(index >= 0); + Debug.Assert(sourceReverseStride ? IsAscending(sourceStrides) : IsDescending(sourceStrides), "Index decomposition requires ordered strides"); + Debug.Assert(sourceStrides.Length == transformStrides.Length); + + int transformIndex = 0; + int remainder = index; + + for (int i = 0; i < sourceStrides.Length; i++) + { + // reverse the index for reverseStride so that we divide by largest stride first + var nIndex = sourceReverseStride ? sourceStrides.Length - 1 - i: i; + + var sourceStride = sourceStrides[nIndex]; + var transformStride = transformStrides[nIndex]; + + transformIndex += transformStride * (remainder / sourceStride); + remainder %= sourceStride; + } + + return transformIndex; + } + } +} diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/Tensors/DenseTensor.cs b/csharp/src/Microsoft.ML.OnnxRuntime/Tensors/DenseTensor.cs new file mode 100644 index 0000000000000..efa193a42b1af --- /dev/null +++ b/csharp/src/Microsoft.ML.OnnxRuntime/Tensors/DenseTensor.cs @@ -0,0 +1,188 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
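The extension methods in ArrayTensorExtensions.cs above are the usual entry point for building tensors from managed arrays; a brief illustrative sketch:

    using Microsoft.ML.OnnxRuntime.Tensors;

    float[] flat = { 1f, 2f, 3f, 4f };
    DenseTensor<float> vector = flat.ToTensor();    // rank-1 tensor, Length == 4

    var grid = new float[2, 3];
    DenseTensor<float> rowMajor = grid.ToTensor();                     // strides {3, 1}
    DenseTensor<float> colMajor = grid.ToTensor(reverseStride: true);  // strides {1, 2}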
+ +// This file is copied and adapted from the following git repository - +// https://github.com/dotnet/corefx +// Commit ID: bdd0814360d4c3a58860919f292a306242f27da1 +// Path: /src/System.Numerics.Tensors/src/System/Numerics/Tensors/DenseTensor.cs +// Original license statement below - + +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System.Runtime.InteropServices; +using System; + +namespace Microsoft.ML.OnnxRuntime.Tensors +{ + /// + /// Represents a multi-dimensional collection of objects of type T that can be accessed by indices. DenseTensor stores values in a contiguous sequential block of memory where all values are represented. + /// + /// type contained within the Tensor. Typically a value type such as int, double, float, etc. + public class DenseTensor : Tensor + { + private readonly Memory memory; + + internal DenseTensor(Array fromArray, bool reverseStride = false) : base(fromArray, reverseStride) + { + // copy initial array + var backingArray = new T[fromArray.Length]; + + int index = 0; + if (reverseStride) + { + // Array is always row-major + var sourceStrides = ArrayUtilities.GetStrides(dimensions); + + foreach (var item in fromArray) + { + var destIndex = ArrayUtilities.TransformIndexByStrides(index++, sourceStrides, false, strides); + backingArray[destIndex] = (T)item; + } + } + else + { + foreach (var item in fromArray) + { + backingArray[index++] = (T)item; + } + } + memory = backingArray; + } + + /// + /// Initializes a rank-1 Tensor using the specified . + /// + /// Size of the 1-dimensional tensor + public DenseTensor(int length) : base(length) + { + memory = new T[length]; + } + + /// + /// Initializes a rank-n Tensor using the dimensions specified in . + /// + /// An span of integers that represent the size of each dimension of the DenseTensor to create. + /// False (default) to indicate that the first dimension is most major (farthest apart) and the last dimension is most minor (closest together): akin to row-major in a rank-2 tensor. True to indicate that the last dimension is most major (farthest apart) and the first dimension is most minor (closest together): akin to column-major in a rank-2 tensor. + public DenseTensor(ReadOnlySpan dimensions, bool reverseStride = false) : base(dimensions, reverseStride) + { + memory = new T[Length]; + } + + /// + /// Constructs a new DenseTensor of the specifed dimensions, wrapping existing backing memory for the contents. + /// + /// + /// An span of integers that represent the size of each dimension of the DenseTensor to create. + /// False (default) to indicate that the first dimension is most major (farthest apart) and the last dimension is most minor (closest together): akin to row-major in a rank-2 tensor. True to indicate that the last dimension is most major (farthest apart) and the first dimension is most minor (closest together): akin to column-major in a rank-2 tensor. + public DenseTensor(Memory memory, ReadOnlySpan dimensions, bool reverseStride = false) : base(dimensions, reverseStride) + { + this.memory = memory; + + if (Length != memory.Length) + { + throw new ArgumentException($"Length of {nameof(memory)} ({memory.Length}) must match product of {nameof(dimensions)} ({Length})."); + } + } + + /// + /// Memory storing backing values of this tensor. 
+ /// + public Memory Buffer => memory; + + /// + /// Gets the value at the specied index, where index is a linearized version of n-dimension indices using strides. + /// + /// An integer index computed as a dot-product of indices. + /// The value at the specified position in this Tensor. + public override T GetValue(int index) + { + return Buffer.Span[index]; + } + + /// + /// Sets the value at the specied index, where index is a linearized version of n-dimension indices using strides. + /// + /// An integer index computed as a dot-product of indices. + /// The new value to set at the specified position in this Tensor. + public override void SetValue(int index, T value) + { + Buffer.Span[index] = value; + } + + protected override void CopyTo(T[] array, int arrayIndex) + { + if (array == null) + { + throw new ArgumentNullException(nameof(array)); + } + if (array.Length < arrayIndex + Length) + { + throw new ArgumentException("The number of elements in the Tensor is greater than the available space from index to the end of the destination array.", nameof(array)); + } + + Buffer.Span.CopyTo(array.AsSpan(arrayIndex)); + } + + protected override int IndexOf(T item) + { + // TODO: use Span.IndexOf when/if it removes the IEquatable type constraint + if (MemoryMarshal.TryGetArray(Buffer, out var arraySegment)) + { + var result = Array.IndexOf(arraySegment.Array, item, arraySegment.Offset, arraySegment.Count); + if (result != -1) + { + result -= arraySegment.Offset; + } + return result; + } + else + { + return base.IndexOf(item); + } + } + + /// + /// Creates a shallow copy of this tensor, with new backing storage. + /// + /// A shallow copy of this tensor. + public override Tensor Clone() + { + return new DenseTensor(Buffer.ToArray(), dimensions, IsReversedStride); + } + + /// + /// Creates a new Tensor of a different type with the specified dimensions and the same layout as this tensor with elements initialized to their default value. + /// + /// Type contained in the returned Tensor. + /// An span of integers that represent the size of each dimension of the DenseTensor to create. + /// A new tensor with the same layout as this tensor but different type and dimensions. + public override Tensor CloneEmpty(ReadOnlySpan dimensions) + { + return new DenseTensor(dimensions, IsReversedStride); + } + + /// + /// Reshapes the current tensor to new dimensions, using the same backing storage. + /// + /// An span of integers that represent the size of each dimension of the DenseTensor to create. + /// A new tensor that reinterprets backing Buffer of this tensor with different dimensions. + public override Tensor Reshape(ReadOnlySpan dimensions) + { + if (dimensions.Length == 0) + { + throw new ArgumentException("Dimensions must contain elements.", nameof(dimensions)); + } + + var newSize = ArrayUtilities.GetProduct(dimensions); + + if (newSize != Length) + { + throw new ArgumentException($"Cannot reshape array due to mismatch in lengths, currently {Length} would become {newSize}.", nameof(dimensions)); + } + + return new DenseTensor(Buffer, dimensions, IsReversedStride); + } + } +} diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/Tensors/Tensor.cs b/csharp/src/Microsoft.ML.OnnxRuntime/Tensors/Tensor.cs new file mode 100644 index 0000000000000..a8eb8b9a27c3e --- /dev/null +++ b/csharp/src/Microsoft.ML.OnnxRuntime/Tensors/Tensor.cs @@ -0,0 +1,1311 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
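A short sketch of the DenseTensor<T>-specific members defined above: wrapping existing memory, writing through the backing Buffer, and reshaping without copying (illustrative only):

    using System;
    using Microsoft.ML.OnnxRuntime.Tensors;

    float[] backing = new float[6];
    // Wraps the existing array; the memory length must equal the product of the dimensions.
    var t = new DenseTensor<float>(backing.AsMemory(), new[] { 2, 3 });

    t.Buffer.Span.Fill(1.0f);                  // bulk-write through the backing memory
    var reshaped = t.Reshape(new[] { 3, 2 });  // same storage, reinterpreted with new dimensions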
+ +// This file is copied and adapted from the following git repository - +// https://github.com/dotnet/corefx +// Commit ID: bdd0814360d4c3a58860919f292a306242f27da1 +// Path: /src/System.Numerics.Tensors/src/System/Numerics/Tensors/Tensor.cs +// Original license statement below - + +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System.Collections; +using System.Collections.Generic; +using System.Diagnostics; +using System.Text; +using System; +using System.Runtime.CompilerServices; + +// Making this assembly's internals visible to the internal Test assembly +[assembly: InternalsVisibleTo("Microsoft.ML.OnnxRuntime.Tests," + + "PublicKey=002400000480000094000000060200000024000052534131000400000100010059013e94e4bc70" + + "136ca4c35f33acd6b62974536b698f9c7a21cee18d805c7ad860ad9eebfdc47a96ba2f8d03f4cf" + + "1c36b9d30787e276c7b9833b5bf2a6eba7e919e6b90083078a352262aed1d842e5f70a3085cbcf" + + "4c56ae851b161137920961c23fcc246598d61d258ccc615c927b2441359eea666a99ce1c3c07dc" + + "a18fb0e1")] + + +namespace Microsoft.ML.OnnxRuntime.Tensors +{ + /// + /// Various methods for creating and manipulating Tensor<T> + /// + public static partial class Tensor + { + /// + /// Creates an identity tensor of the specified size. An identity tensor is a two dimensional tensor with 1s in the diagonal. + /// + /// type contained within the Tensor. Typically a value type such as int, double, float, etc. + /// Width and height of the identity tensor to create. + /// a by with 1s along the diagonal and zeros elsewhere. + public static Tensor CreateIdentity(int size) + { + return CreateIdentity(size, false, Tensor.One); + } + + /// + /// Creates an identity tensor of the specified size and layout (row vs column major). An identity tensor is a two dimensional tensor with 1s in the diagonal. + /// + /// type contained within the Tensor. Typically a value type such as int, double, float, etc. + /// Width and height of the identity tensor to create. + /// >False to indicate that the first dimension is most minor (closest) and the last dimension is most major (farthest): row-major. True to indicate that the last dimension is most minor (closest together) and the first dimension is most major (farthest apart): column-major. + /// a by with 1s along the diagonal and zeros elsewhere. + public static Tensor CreateIdentity(int size, bool columMajor) + { + return CreateIdentity(size, columMajor, Tensor.One); + } + + /// + /// Creates an identity tensor of the specified size and layout (row vs column major) using the specified one value. An identity tensor is a two dimensional tensor with 1s in the diagonal. This may be used in case T is a type that doesn't have a known 1 value. + /// + /// type contained within the Tensor. Typically a value type such as int, double, float, etc. + /// Width and height of the identity tensor to create. + /// >False to indicate that the first dimension is most minor (closest) and the last dimension is most major (farthest): row-major. True to indicate that the last dimension is most minor (closest together) and the first dimension is most major (farthest apart): column-major. + /// Value of that is used along the diagonal. + /// a by with 1s along the diagonal and zeros elsewhere. 
+ public static Tensor CreateIdentity(int size, bool columMajor, T oneValue) + { + unsafe + { + Span dimensions = stackalloc int[2]; + dimensions[0] = dimensions[1] = size; + + var result = new DenseTensor(dimensions, columMajor); + + for (int i = 0; i < size; i++) + { + result.SetValue(i * size + i, oneValue); + } + + return result; + } + } + + /// + /// Creates a n+1-rank tensor using the specified n-rank diagonal. Values not on the diagonal will be filled with zeros. + /// + /// type contained within the Tensor. Typically a value type such as int, double, float, etc. + /// Tensor representing the diagonal to build the new tensor from. + /// A new tensor of the same layout and order as of one higher rank, with the values of along the diagonal and zeros elsewhere. + public static Tensor CreateFromDiagonal(Tensor diagonal) + { + return CreateFromDiagonal(diagonal, 0); + } + + /// + /// Creates a n+1-dimension tensor using the specified n-dimension diagonal at the specified offset from the center. Values not on the diagonal will be filled with zeros. + /// + /// type contained within the Tensor. Typically a value type such as int, double, float, etc. + /// Tensor representing the diagonal to build the new tensor from. + /// Offset of diagonal to set in returned tensor. 0 for the main diagonal, less than zero for diagonals below, greater than zero from diagonals above. + /// A new tensor of the same layout and order as of one higher rank, with the values of along the specified diagonal and zeros elsewhere. + public static Tensor CreateFromDiagonal(Tensor diagonal, int offset) + { + if (diagonal.Rank < 1) + { + throw new ArgumentException($"Tensor {nameof(diagonal)} must have at least one dimension.", nameof(diagonal)); + } + + int diagonalLength = diagonal.dimensions[0]; + + // TODO: allow specification of axis1 and axis2? + var rank = diagonal.dimensions.Length + 1; + Span dimensions = rank < ArrayUtilities.StackallocMax ? stackalloc int[rank] : new int[rank]; + + // assume square + var axisLength = diagonalLength + Math.Abs(offset); + dimensions[0] = dimensions[1] = axisLength; + + for (int i = 1; i < diagonal.dimensions.Length; i++) + { + dimensions[i + 1] = diagonal.dimensions[i]; + } + + var result = diagonal.CloneEmpty(dimensions); + + var sizePerDiagonal = diagonal.Length / diagonalLength; + + var diagProjectionStride = diagonal.IsReversedStride && diagonal.Rank > 1 ? diagonal.strides[1] : 1; + var resultProjectionStride = result.IsReversedStride && result.Rank > 2 ? result.strides[2] : 1; + + for (int diagIndex = 0; diagIndex < diagonalLength; diagIndex++) + { + var resultIndex0 = offset < 0 ? diagIndex - offset : diagIndex; + var resultIndex1 = offset > 0 ? diagIndex + offset : diagIndex; + + var resultBase = resultIndex0 * result.strides[0] + resultIndex1 * result.strides[1]; + var diagBase = diagIndex * diagonal.strides[0]; + + for (int diagProjectionOffset = 0; diagProjectionOffset < sizePerDiagonal; diagProjectionOffset++) + { + result.SetValue(resultBase + diagProjectionOffset * resultProjectionStride, + diagonal.GetValue(diagBase + diagProjectionOffset * diagProjectionStride)); + } + } + + return result; + } + } + + /// + /// Represents a multi-dimensional collection of objects of type T that can be accessed by indices. + /// + /// type contained within the Tensor. Typically a value type such as int, double, float, etc. + [DebuggerDisplay("{GetArrayString(false)}")] + // When we cross-compile for frameworks that expose ICloneable this must implement ICloneable as well. 
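The static factory methods above can be exercised as sketched here (the explicit generic type arguments follow the corefx Tensor code this file is copied from; treat the exact signatures as an assumption):

    using Microsoft.ML.OnnxRuntime.Tensors;

    // 3x3 identity: ones on the diagonal, zeros elsewhere.
    Tensor<double> eye = Tensor.CreateIdentity<double>(3);

    // Lift a rank-1 diagonal {1, 2, 3} into a 3x3 matrix with those values on the main diagonal.
    Tensor<double> diag = Tensor.CreateFromDiagonal(new double[] { 1, 2, 3 }.ToTensor());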
+ public abstract class Tensor : IList, IList, IReadOnlyList, IStructuralComparable, IStructuralEquatable + { + internal static T Zero + { + get + { + if (typeof(T) == typeof(bool)) + { + return (T)(object)(false); + } + else if (typeof(T) == typeof(byte)) + { + return (T)(object)(byte)(0); + } + else if (typeof(T) == typeof(char)) + { + return (T)(object)(char)(0); + } + else if (typeof(T) == typeof(decimal)) + { + return (T)(object)(decimal)(0); + } + else if (typeof(T) == typeof(double)) + { + return (T)(object)(double)(0); + } + else if (typeof(T) == typeof(float)) + { + return (T)(object)(float)(0); + } + else if (typeof(T) == typeof(int)) + { + return (T)(object)(int)(0); + } + else if (typeof(T) == typeof(long)) + { + return (T)(object)(long)(0); + } + else if (typeof(T) == typeof(sbyte)) + { + return (T)(object)(sbyte)(0); + } + else if (typeof(T) == typeof(short)) + { + return (T)(object)(short)(0); + } + else if (typeof(T) == typeof(uint)) + { + return (T)(object)(uint)(0); + } + else if (typeof(T) == typeof(ulong)) + { + return (T)(object)(ulong)(0); + } + else if (typeof(T) == typeof(ushort)) + { + return (T)(object)(ushort)(0); + } + + throw new NotSupportedException(); + } + } + + internal static T One + { + get + { + if (typeof(T) == typeof(bool)) + { + return (T)(object)(true); + } + else if (typeof(T) == typeof(byte)) + { + return (T)(object)(byte)(1); + } + else if (typeof(T) == typeof(char)) + { + return (T)(object)(char)(1); + } + else if (typeof(T) == typeof(decimal)) + { + return (T)(object)(decimal)(1); + } + else if (typeof(T) == typeof(double)) + { + return (T)(object)(double)(1); + } + else if (typeof(T) == typeof(float)) + { + return (T)(object)(float)(1); + } + else if (typeof(T) == typeof(int)) + { + return (T)(object)(int)(1); + } + else if (typeof(T) == typeof(long)) + { + return (T)(object)(long)(1); + } + else if (typeof(T) == typeof(sbyte)) + { + return (T)(object)(sbyte)(1); + } + else if (typeof(T) == typeof(short)) + { + return (T)(object)(short)(1); + } + else if (typeof(T) == typeof(uint)) + { + return (T)(object)(uint)(1); + } + else if (typeof(T) == typeof(ulong)) + { + return (T)(object)(ulong)(1); + } + else if (typeof(T) == typeof(ushort)) + { + return (T)(object)(ushort)(1); + } + + throw new NotSupportedException(); + } + } + + internal readonly int[] dimensions; + internal readonly int[] strides; + private readonly bool isReversedStride; + + private readonly long length; + + /// + /// Initialize a 1-dimensional tensor of the specified length + /// + /// Size of the 1-dimensional tensor + protected Tensor(int length) + { + dimensions = new[] { length }; + strides = new[] { 1 }; + isReversedStride = false; + this.length = length; + } + + /// + /// Initialize an n-dimensional tensor with the specified dimensions and layout. ReverseStride=true gives a stride of 1-element witdth to the first dimension (0). ReverseStride=false gives a stride of 1-element width to the last dimension (n-1). + /// + /// An span of integers that represent the size of each dimension of the Tensor to create. + /// False (default) to indicate that the first dimension is most major (farthest apart) and the last dimension is most minor (closest together): akin to row-major in a rank-2 tensor. True to indicate that the last dimension is most major (farthest apart) and the first dimension is most minor (closest together): akin to column-major in a rank-2 tensor. 
+ protected Tensor(ReadOnlySpan dimensions, bool reverseStride) + { + if (dimensions.Length == 0) + { + throw new ArgumentException("Dimensions must contain elements.", nameof(dimensions)); + } + + this.dimensions = new int[dimensions.Length]; + long size = 1; + for (int i = 0; i < dimensions.Length; i++) + { + if (dimensions[i] < 1) + { + throw new ArgumentOutOfRangeException(nameof(dimensions), "Dimensions must be positive and non-zero"); + } + this.dimensions[i] = dimensions[i]; + size *= dimensions[i]; + } + + strides = ArrayUtilities.GetStrides(dimensions, reverseStride); + isReversedStride = reverseStride; + + length = size; + } + + /// + /// Initializes tensor with same dimensions as array, content of array is ignored. ReverseStride=true gives a stride of 1-element witdth to the first dimension (0). ReverseStride=false gives a stride of 1-element width to the last dimension (n-1). + /// + /// Array from which to derive dimensions. + /// False (default) to indicate that the first dimension is most major (farthest apart) and the last dimension is most minor (closest together): akin to row-major in a rank-2 tensor. True to indicate that the last dimension is most major (farthest apart) and the first dimension is most minor (closest together): akin to column-major in a rank-2 tensor. + protected Tensor(Array fromArray, bool reverseStride) + { + if (fromArray == null) + { + throw new ArgumentNullException(nameof(fromArray)); + } + + if (fromArray.Rank == 0) + { + throw new ArgumentException("Array must contain elements.", nameof(fromArray)); + } + + dimensions = new int[fromArray.Rank]; + long size = 1; + for (int i = 0; i < dimensions.Length; i++) + { + dimensions[i] = fromArray.GetLength(i); + size *= dimensions[i]; + } + + strides = ArrayUtilities.GetStrides(dimensions, reverseStride); + isReversedStride = reverseStride; + + length = size; + } + + /// + /// Total length of the Tensor. + /// + public long Length => length; + + /// + /// Rank of the tensor: number of dimensions. + /// + public int Rank => dimensions.Length; + + /// + /// True if strides are reversed (AKA Column-major) + /// + public bool IsReversedStride => isReversedStride; + + /// + /// Returns a readonly view of the dimensions of this tensor. + /// + public ReadOnlySpan Dimensions => dimensions; + + /// + /// Returns a readonly view of the strides of this tensor. + /// + public ReadOnlySpan Strides => strides; + + /// + /// Sets all elements in Tensor to . + /// + /// Value to fill + public virtual void Fill(T value) + { + for (int i = 0; i < Length; i++) + { + SetValue(i, value); + } + } + + /// + /// Creates a shallow copy of this tensor, with new backing storage. + /// + /// A shallow copy of this tensor. + public abstract Tensor Clone(); + + /// + /// Creates a new Tensor with the same layout and dimensions as this tensor with elements initialized to their default value. + /// + /// A new Tensor with the same layout and dimensions as this tensor with elements initialized to their default value. + public virtual Tensor CloneEmpty() + { + return CloneEmpty(dimensions); + } + + /// + /// Creates a new Tensor with the specified dimensions and the same layout as this tensor with elements initialized to their default value. + /// + /// An span of integers that represent the size of each dimension of the DenseTensor to create. + /// A new Tensor with the same layout as this tensor and specified with elements initialized to their default value. 
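For completeness, the shared base-class surface defined above behaves as in this small sketch:

    using Microsoft.ML.OnnxRuntime.Tensors;

    var t = new DenseTensor<int>(new[] { 2, 3 });
    // Rank == 2, Length == 6, Dimensions == {2, 3}, Strides == {3, 1}, IsReversedStride == false
    t.Fill(7);                           // every element becomes 7
    Tensor<int> copy = t.Clone();        // same shape and values, new backing storage
    Tensor<int> empty = t.CloneEmpty();  // same shape, elements default-initialized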
+ public virtual Tensor CloneEmpty(ReadOnlySpan dimensions) + { + return CloneEmpty(dimensions); + } + + /// + /// Creates a new Tensor of a different type with the same layout and size as this tensor with elements initialized to their default value. + /// + /// Type contained within the new Tensor. Typically a value type such as int, double, float, etc. + /// A new Tensor with the same layout and dimensions as this tensor with elements of type initialized to their default value. + public virtual Tensor CloneEmpty() + { + return CloneEmpty(dimensions); + } + + /// + /// Creates a new Tensor of a different type with the specified dimensions and the same layout as this tensor with elements initialized to their default value. + /// + /// Type contained within the new Tensor. Typically a value type such as int, double, float, etc. + /// An span of integers that represent the size of each dimension of the DenseTensor to create. + /// A new Tensor with the same layout as this tensor of specified with elements of type initialized to their default value. + public abstract Tensor CloneEmpty(ReadOnlySpan dimensions); + + /// + /// Gets the n-1 dimension diagonal from the n dimension tensor. + /// + /// An n-1 dimension tensor with the values from the main diagonal of this tensor. + public Tensor GetDiagonal() + { + return GetDiagonal(0); + } + + /// + /// Gets the n-1 dimension diagonal from the n dimension tensor at the specified offset from center. + /// + /// Offset of diagonal to set in returned tensor. 0 for the main diagonal, less than zero for diagonals below, greater than zero from diagonals above. + /// An n-1 dimension tensor with the values from the specified diagonal of this tensor. + public Tensor GetDiagonal(int offset) + { + // Get diagonal of first two dimensions for all remaining dimensions + + // diagnonal is as follows: + // { 1, 2, 4 } + // { 8, 3, 9 } + // { 0, 7, 5 } + // The diagonal at offset 0 is { 1, 3, 5 } + // The diagonal at offset 1 is { 2, 9 } + // The diagonal at offset -1 is { 8, 7 } + + if (Rank < 2) + { + throw new InvalidOperationException($"Cannot compute diagonal of {nameof(Tensor)} with Rank less than 2."); + } + + // TODO: allow specification of axis1 and axis2? + var axisLength0 = dimensions[0]; + var axisLength1 = dimensions[1]; + + // the diagonal will be the length of the smaller axis + // if offset it positive, the length will shift along the second axis + // if the offsett is negative, the length will shift along the first axis + // In that way the length of the diagonal will be + // Min(offset < 0 ? axisLength0 + offset : axisLength0, offset > 0 ? axisLength1 - offset : axisLength1) + // To illustrate, consider the following + // { 1, 2, 4, 3, 7 } + // { 8, 3, 9, 2, 6 } + // { 0, 7, 5, 2, 9 } + // The diagonal at offset 0 is { 1, 3, 5 }, Min(3, 5) = 3 + // The diagonal at offset 1 is { 2, 9, 2 }, Min(3, 5 - 1) = 3 + // The diagonal at offset 3 is { 3, 6 }, Min(3, 5 - 3) = 2 + // The diagonal at offset -1 is { 8, 7 }, Min(3 - 1, 5) = 2 + var offsetAxisLength0 = offset < 0 ? axisLength0 + offset : axisLength0; + var offsetAxisLength1 = offset > 0 ? axisLength1 - offset : axisLength1; + + var diagonalLength = Math.Min(offsetAxisLength0, offsetAxisLength1); + + if (diagonalLength <= 0) + { + throw new ArgumentException($"Cannot compute diagonal with offset {offset}", nameof(offset)); + } + + var newTensorRank = Rank - 1; + var newTensorDimensions = newTensorRank < ArrayUtilities.StackallocMax ? 
stackalloc int[newTensorRank] : new int[newTensorRank]; + newTensorDimensions[0] = diagonalLength; + + for (int i = 2; i < dimensions.Length; i++) + { + newTensorDimensions[i - 1] = dimensions[i]; + } + + var diagonalTensor = CloneEmpty(newTensorDimensions); + var sizePerDiagonal = diagonalTensor.Length / diagonalTensor.Dimensions[0]; + + var diagProjectionStride = diagonalTensor.IsReversedStride && diagonalTensor.Rank > 1 ? diagonalTensor.strides[1] : 1; + var sourceProjectionStride = IsReversedStride && Rank > 2 ? strides[2] : 1; + + for (int diagIndex = 0; diagIndex < diagonalLength; diagIndex++) + { + var sourceIndex0 = offset < 0 ? diagIndex - offset : diagIndex; + var sourceIndex1 = offset > 0 ? diagIndex + offset : diagIndex; + + var sourceBase = sourceIndex0 * strides[0] + sourceIndex1 * strides[1]; + var diagBase = diagIndex * diagonalTensor.strides[0]; + + for (int diagProjectionIndex = 0; diagProjectionIndex < sizePerDiagonal; diagProjectionIndex++) + { + diagonalTensor.SetValue(diagBase + diagProjectionIndex * diagProjectionStride, + GetValue(sourceBase + diagProjectionIndex * sourceProjectionStride)); + } + } + + return diagonalTensor; + } + + /// + /// Gets a tensor representing the elements below and including the diagonal, with the rest of the elements zero-ed. + /// + /// A tensor with the values from this tensor at and below the main diagonal and zeros elsewhere. + public Tensor GetTriangle() + { + return GetTriangle(0, upper: false); + } + + /// + /// Gets a tensor representing the elements below and including the specified diagonal, with the rest of the elements zero-ed. + /// + /// Offset of diagonal to set in returned tensor. 0 for the main diagonal, less than zero for diagonals below, greater than zero from diagonals above. + /// A tensor with the values from this tensor at and below the specified diagonal and zeros elsewhere. + public Tensor GetTriangle(int offset) + { + return GetTriangle(offset, upper: false); + } + + /// + /// Gets a tensor representing the elements above and including the diagonal, with the rest of the elements zero-ed. + /// + /// A tensor with the values from this tensor at and above the main diagonal and zeros elsewhere. + public Tensor GetUpperTriangle() + { + return GetTriangle(0, upper: true); + } + + /// + /// Gets a tensor representing the elements above and including the specified diagonal, with the rest of the elements zero-ed. + /// + /// Offset of diagonal to set in returned tensor. 0 for the main diagonal, less than zero for diagonals below, greater than zero from diagonals above. + /// A tensor with the values from this tensor at and above the specified diagonal and zeros elsewhere. + public Tensor GetUpperTriangle(int offset) + { + return GetTriangle(offset, upper: true); + } + + public Tensor GetTriangle(int offset, bool upper) + { + if (Rank < 2) + { + throw new InvalidOperationException($"Cannot compute triangle of {nameof(Tensor)} with Rank less than 2."); + } + + // Similar to get diagonal except it gets every element below and including the diagonal. + + // TODO: allow specification of axis1 and axis2? + var axisLength0 = dimensions[0]; + var axisLength1 = dimensions[1]; + var diagonalLength = Math.Max(axisLength0, axisLength1); + + var result = CloneEmpty(); + + var projectionSize = Length / (axisLength0 * axisLength1); + var projectionStride = IsReversedStride && Rank > 2 ? strides[2] : 1; + + for (int diagIndex = 0; diagIndex < diagonalLength; diagIndex++) + { + // starting point for the tri + var triIndex0 = offset > 0 ? 
diagIndex - offset : diagIndex; + var triIndex1 = offset > 0 ? diagIndex : diagIndex + offset; + + // for lower triangle, iterate index0 keeping same index1 + // for upper triangle, iterate index1 keeping same index0 + + if (triIndex0 < 0) + { + if (upper) + { + // out of bounds, ignore this diagIndex. + continue; + } + else + { + // set index to 0 so that we can iterate on the remaining index0 values. + triIndex0 = 0; + } + } + + if (triIndex1 < 0) + { + if (upper) + { + // set index to 0 so that we can iterate on the remaining index1 values. + triIndex1 = 0; + } + else + { + // out of bounds, ignore this diagIndex. + continue; + } + } + + while ((triIndex1 < axisLength1) && (triIndex0 < axisLength0)) + { + var baseIndex = triIndex0 * strides[0] + triIndex1 * result.strides[1]; + + for (int projectionIndex = 0; projectionIndex < projectionSize; projectionIndex++) + { + var index = baseIndex + projectionIndex * projectionStride; + + result.SetValue(index, GetValue(index)); + } + + if (upper) + { + triIndex1++; + } + else + { + triIndex0++; + } + } + } + + return result; + } + + /// + /// Reshapes the current tensor to new dimensions, using the same backing storage if possible. + /// + /// An span of integers that represent the size of each dimension of the Tensor to create. + /// A new tensor that reinterprets this tensor with different dimensions. + public abstract Tensor Reshape(ReadOnlySpan dimensions); + + /// + /// Obtains the value at the specified indices + /// + /// A one-dimensional array of integers that represent the indices specifying the position of the element to get. + /// The value at the specified position in this Tensor. + public virtual T this[params int[] indices] + { + get + { + if (indices == null) + { + throw new ArgumentNullException(nameof(indices)); + } + var span = new ReadOnlySpan(indices); + return this[span]; + } + + set + { + if (indices == null) + { + throw new ArgumentNullException(nameof(indices)); + } + var span = new ReadOnlySpan(indices); + this[span] = value; + } + } + + /// + /// Obtains the value at the specified indices + /// + /// A span integers that represent the indices specifying the position of the element to get. + /// The value at the specified position in this Tensor. + public virtual T this[ReadOnlySpan indices] + { + get + { + return GetValue(ArrayUtilities.GetIndex(strides, indices)); + } + + set + { + SetValue(ArrayUtilities.GetIndex(strides, indices), value); + } + } + + /// + /// Gets the value at the specied index, where index is a linearized version of n-dimension indices using strides. + /// + /// An integer index computed as a dot-product of indices. + /// The value at the specified position in this Tensor. + public abstract T GetValue(int index); + + /// + /// Sets the value at the specied index, where index is a linearized version of n-dimension indices using strides. + /// + /// An integer index computed as a dot-product of indices. + /// The new value to set at the specified position in this Tensor. + public abstract void SetValue(int index, T value); + + + #region statics + /// + /// Performs a value comparison of the content and shape of two tensors. Two tensors are equal if they have the same shape and same value at every set of indices. If not equal a tensor is greater or less than another tensor based on the first non-equal element when enumerating in linear order. 
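The diagonal/triangle helpers above can be checked against the worked values in their own comments (an illustrative sketch):

    using Microsoft.ML.OnnxRuntime.Tensors;

    var m = new[,] { { 1, 2, 4 },
                     { 8, 3, 9 },
                     { 0, 7, 5 } }.ToTensor();

    var mainDiag  = m.GetDiagonal();      // { 1, 3, 5 }
    var upperDiag = m.GetDiagonal(1);     // { 2, 9 }
    var lower = m.GetTriangle();          // elements above the main diagonal zeroed
    var upper = m.GetUpperTriangle();     // elements below the main diagonal zeroed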
+ /// + /// + /// + /// + public static int Compare(Tensor left, Tensor right) + { + return StructuralComparisons.StructuralComparer.Compare(left, right); + } + + /// + /// Performs a value equality comparison of the content of two tensors. Two tensors are equal if they have the same shape and same value at every set of indices. + /// + /// + /// + /// + public static bool Equals(Tensor left, Tensor right) + { + return StructuralComparisons.StructuralEqualityComparer.Equals(left, right); + } + #endregion + + #region IEnumerable members + IEnumerator IEnumerable.GetEnumerator() + { + return ((IEnumerable)this).GetEnumerator(); + } + #endregion + + #region ICollection members + int ICollection.Count => (int)Length; + + bool ICollection.IsSynchronized => false; + + object ICollection.SyncRoot => this; // backingArray.this? + + void ICollection.CopyTo(Array array, int index) + { + if (array is T[] destinationArray) + { + CopyTo(destinationArray, index); + } + else + { + if (array == null) + { + throw new ArgumentNullException(nameof(array)); + } + if (array.Rank != 1) + { + throw new ArgumentException("Only single dimensional arrays are supported for the requested action.", nameof(array)); + } + if (array.Length < index + Length) + { + throw new ArgumentException("The number of elements in the Tensor is greater than the available space from index to the end of the destination array.", nameof(array)); + } + + for (int i = 0; i < length; i++) + { + array.SetValue(GetValue(i), index + i); + } + } + } + #endregion + + #region IList members + object IList.this[int index] + { + get + { + return GetValue(index); + } + set + { + try + { + SetValue(index, (T)value); + } + catch (InvalidCastException) + { + throw new ArgumentException($"The value \"{value}\" is not of type \"{typeof(T)}\" and cannot be used in this generic collection."); + } + } + } + + public bool IsFixedSize => true; + + public bool IsReadOnly => false; + + int IList.Add(object value) + { + throw new InvalidOperationException(); + } + + void IList.Clear() + { + Fill(default(T)); + } + + bool IList.Contains(object value) + { + if (IsCompatibleObject(value)) + { + return Contains((T)value); + } + return false; + } + + int IList.IndexOf(object value) + { + if (IsCompatibleObject(value)) + { + return IndexOf((T)value); + } + return -1; + } + + void IList.Insert(int index, object value) + { + throw new InvalidOperationException(); + } + + void IList.Remove(object value) + { + throw new InvalidOperationException(); + } + + void IList.RemoveAt(int index) + { + throw new InvalidOperationException(); + } + #endregion + + #region IEnumerable members + IEnumerator IEnumerable.GetEnumerator() + { + for (int i = 0; i < Length; i++) + { + yield return GetValue(i); + } + } + #endregion + + #region ICollection members + int ICollection.Count => (int)Length; + + void ICollection.Add(T item) + { + throw new InvalidOperationException(); + } + + void ICollection.Clear() + { + Fill(default(T)); + } + + bool ICollection.Contains(T item) + { + return Contains(item); + } + + /// + /// Determines whether an element is in the Tensor<T>. + /// + /// + /// The object to locate in the Tensor<T>. The value can be null for reference types. + /// + /// + /// true if item is found in the Tensor<T>; otherwise, false. 
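Because most of the collection interfaces above are implemented explicitly, they are reached through a cast, while enumeration works directly (illustrative):

    using System.Collections.Generic;
    using Microsoft.ML.OnnxRuntime.Tensors;

    var t = new[] { 10, 20, 30 }.ToTensor();

    foreach (int v in t) { /* visits 10, 20, 30 in linear order */ }

    IList<int> view = t;              // Tensor<T> implements IList<T>
    int position = view.IndexOf(20);  // 1
    bool found = view.Contains(30);   // true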
+ /// + protected virtual bool Contains(T item) + { + return Length != 0 && IndexOf(item) != -1; + } + + void ICollection.CopyTo(T[] array, int arrayIndex) + { + CopyTo(array, arrayIndex); + } + + /// + /// Copies the elements of the Tensor<T> to an Array, starting at a particular Array index. + /// + /// + /// The one-dimensional Array that is the destination of the elements copied from Tensor<T>. The Array must have zero-based indexing. + /// + /// + /// The zero-based index in array at which copying begins. + /// + protected virtual void CopyTo(T[] array, int arrayIndex) + { + if (array == null) + { + throw new ArgumentNullException(nameof(array)); + } + if (array.Length < arrayIndex + Length) + { + throw new ArgumentException("The number of elements in the Tensor is greater than the available space from index to the end of the destination array.", nameof(array)); + } + + for (int i = 0; i < length; i++) + { + array[arrayIndex + i] = GetValue(i); + } + } + + bool ICollection.Remove(T item) + { + throw new InvalidOperationException(); + } + #endregion + + #region IReadOnlyCollection members + + int IReadOnlyCollection.Count => (int)Length; + + #endregion + + #region IList members + T IList.this[int index] + { + get { return GetValue(index); } + set { SetValue(index, value); } + } + + int IList.IndexOf(T item) + { + return IndexOf(item); + } + + /// + /// Determines the index of a specific item in the Tensor<T>. + /// + /// The object to locate in the Tensor<T>. + /// The index of item if found in the tensor; otherwise, -1. + protected virtual int IndexOf(T item) + { + for (int i = 0; i < Length; i++) + { + if (GetValue(i).Equals(item)) + { + return i; + } + } + + return -1; + } + + void IList.Insert(int index, T item) + { + throw new InvalidOperationException(); + } + + void IList.RemoveAt(int index) + { + throw new InvalidOperationException(); + } + #endregion + + #region IReadOnlyList members + + T IReadOnlyList.this[int index] => GetValue(index); + + #endregion + + #region IStructuralComparable members + int IStructuralComparable.CompareTo(object other, IComparer comparer) + { + if (other == null) + { + return 1; + } + + if (other is Tensor) + { + return CompareTo((Tensor)other, comparer); + } + + var otherArray = other as Array; + + if (otherArray != null) + { + return CompareTo(otherArray, comparer); + } + + throw new ArgumentException($"Cannot compare {nameof(Tensor)} to {other.GetType()}.", nameof(other)); + } + + private int CompareTo(Tensor other, IComparer comparer) + { + if (Rank != other.Rank) + { + throw new ArgumentException($"Cannot compare {nameof(Tensor)} with Rank {Rank} to {nameof(other)} with Rank {other.Rank}.", nameof(other)); + } + + for (int i = 0; i < dimensions.Length; i++) + { + if (dimensions[i] != other.dimensions[i]) + { + throw new ArgumentException($"Cannot compare {nameof(Tensor)}s with differning dimension {i}, {dimensions[i]} != {other.dimensions[i]}.", nameof(other)); + } + } + + int result = 0; + + if (IsReversedStride == other.IsReversedStride) + { + for (int i = 0; i < Length; i++) + { + result = comparer.Compare(GetValue(i), other.GetValue(i)); + if (result != 0) + { + break; + } + } + } + else + { + var indices = Rank < ArrayUtilities.StackallocMax ? 
stackalloc int[Rank] : new int[Rank]; + for (int i = 0; i < Length; i++) + { + ArrayUtilities.GetIndices(strides, IsReversedStride, i, indices); + result = comparer.Compare(this[indices], other[indices]); + if (result != 0) + { + break; + } + } + } + + return result; + } + + private int CompareTo(Array other, IComparer comparer) + { + if (Rank != other.Rank) + { + throw new ArgumentException($"Cannot compare {nameof(Tensor)} with Rank {Rank} to {nameof(Array)} with rank {other.Rank}.", nameof(other)); + } + + for (int i = 0; i < dimensions.Length; i++) + { + var otherDimension = other.GetLength(i); + if (dimensions[i] != otherDimension) + { + throw new ArgumentException($"Cannot compare {nameof(Tensor)} to {nameof(Array)} with differning dimension {i}, {dimensions[i]} != {otherDimension}.", nameof(other)); + } + } + + int result = 0; + var indices = new int[Rank]; + for (int i = 0; i < Length; i++) + { + ArrayUtilities.GetIndices(strides, IsReversedStride, i, indices); + + result = comparer.Compare(GetValue(i), other.GetValue(indices)); + + if (result != 0) + { + break; + } + } + + return result; + } + #endregion + + #region IStructuralEquatable members + bool IStructuralEquatable.Equals(object other, IEqualityComparer comparer) + { + if (other == null) + { + return false; + } + + if (other is Tensor) + { + return Equals((Tensor)other, comparer); + } + + var otherArray = other as Array; + + if (otherArray != null) + { + return Equals(otherArray, comparer); + } + + throw new ArgumentException($"Cannot compare {nameof(Tensor)} to {other.GetType()}.", nameof(other)); + } + + private bool Equals(Tensor other, IEqualityComparer comparer) + { + if (Rank != other.Rank) + { + throw new ArgumentException($"Cannot compare {nameof(Tensor)} with Rank {Rank} to {nameof(other)} with Rank {other.Rank}.", nameof(other)); + } + + for (int i = 0; i < dimensions.Length; i++) + { + if (dimensions[i] != other.dimensions[i]) + { + throw new ArgumentException($"Cannot compare {nameof(Tensor)}s with differning dimension {i}, {dimensions[i]} != {other.dimensions[i]}.", nameof(other)); + } + } + + if (IsReversedStride == other.IsReversedStride) + { + for (int i = 0; i < Length; i++) + { + if (!comparer.Equals(GetValue(i), other.GetValue(i))) + { + return false; + } + } + } + else + { + var indices = Rank < ArrayUtilities.StackallocMax ? 
stackalloc int[Rank] : new int[Rank]; + for (int i = 0; i < Length; i++) + { + ArrayUtilities.GetIndices(strides, IsReversedStride, i, indices); + + if (!comparer.Equals(this[indices], other[indices])) + { + return false; + } + } + } + + return true; + } + + private bool Equals(Array other, IEqualityComparer comparer) + { + if (Rank != other.Rank) + { + throw new ArgumentException($"Cannot compare {nameof(Tensor)} with Rank {Rank} to {nameof(Array)} with rank {other.Rank}.", nameof(other)); + } + + for (int i = 0; i < dimensions.Length; i++) + { + var otherDimension = other.GetLength(i); + if (dimensions[i] != otherDimension) + { + throw new ArgumentException($"Cannot compare {nameof(Tensor)} to {nameof(Array)} with differning dimension {i}, {dimensions[i]} != {otherDimension}.", nameof(other)); + } + } + + var indices = new int[Rank]; + for (int i = 0; i < Length; i++) + { + ArrayUtilities.GetIndices(strides, IsReversedStride, i, indices); + + if (!comparer.Equals(GetValue(i), other.GetValue(indices))) + { + return false; + } + } + + return true; + } + int IStructuralEquatable.GetHashCode(IEqualityComparer comparer) + { + int hashCode = 0; + // this ignores shape, which is fine it just means we'll have hash collisions for things + // with the same content and different shape. + for (int i = 0; i < Length; i++) + { + hashCode ^= comparer.GetHashCode(GetValue(i)); + } + + return hashCode; + } + #endregion + + #region Translations + + /// + /// Creates a copy of this tensor as a DenseTensor<T>. If this tensor is already a DenseTensor<T> calling this method is equivalent to calling Clone(). + /// + /// + public virtual DenseTensor ToDenseTensor() + { + var denseTensor = new DenseTensor(Dimensions, IsReversedStride); + for (int i = 0; i < Length; i++) + { + denseTensor.SetValue(i, GetValue(i)); + } + return denseTensor; + } + + #endregion + + public string GetArrayString(bool includeWhitespace = true) + { + var builder = new StringBuilder(); + + var strides = ArrayUtilities.GetStrides(dimensions); + var indices = new int[Rank]; + var innerDimension = Rank - 1; + var innerLength = dimensions[innerDimension]; + var outerLength = Length / innerLength; + + int indent = 0; + for (int outerIndex = 0; outerIndex < Length; outerIndex += innerLength) + { + ArrayUtilities.GetIndices(strides, false, outerIndex, indices); + + while ((indent < innerDimension) && (indices[indent] == 0)) + { + // start up + if (includeWhitespace) + { + Indent(builder, indent); + } + indent++; + builder.Append('{'); + if (includeWhitespace) + { + builder.AppendLine(); + } + } + + for (int innerIndex = 0; innerIndex < innerLength; innerIndex++) + { + indices[innerDimension] = innerIndex; + + if ((innerIndex == 0)) + { + if (includeWhitespace) + { + Indent(builder, indent); + } + builder.Append('{'); + } + else + { + builder.Append(','); + } + builder.Append(this[indices]); + } + builder.Append('}'); + + for (int i = Rank - 2; i >= 0; i--) + { + var lastIndex = dimensions[i] - 1; + if (indices[i] == lastIndex) + { + // close out + --indent; + if (includeWhitespace) + { + builder.AppendLine(); + Indent(builder, indent); + } + builder.Append('}'); + } + else + { + builder.Append(','); + if (includeWhitespace) + { + builder.AppendLine(); + } + break; + } + } + } + + return builder.ToString(); + } + + private static void Indent(StringBuilder builder, int tabs, int spacesPerTab = 4) + { + for (int tab = 0; tab < tabs; tab++) + { + for (int space = 0; space < spacesPerTab; space++) + { + builder.Append(' '); + } + } + } + + private 
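For reference, a sketch of what GetArrayString above produces for a small tensor (assuming the DenseTensor<T> constructor used by the tests in this patch):

    var t = new DenseTensor<int>(new[] { 1, 2, 3, 4, 5, 6 }, new[] { 2, 3 });
    // t.GetArrayString() returns (includeWhitespace defaults to true):
    // {
    //     {1,2,3},
    //     {4,5,6}
    // }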
static bool IsCompatibleObject(object value) + { + // Non-null values are fine. Only accept nulls if T is a class or Nullable. + // Note that default(T) is not equal to null for value types except when T is Nullable. + return ((value is T) || (value == null && default(T) == null)); + } + } +} diff --git a/csharp/test/Microsoft.ML.OnnxRuntime.EndToEndTests.Capi/CXX_Api_Sample.cpp b/csharp/test/Microsoft.ML.OnnxRuntime.EndToEndTests.Capi/CXX_Api_Sample.cpp index 34765ab133fa2..559d3690e9664 100644 --- a/csharp/test/Microsoft.ML.OnnxRuntime.EndToEndTests.Capi/CXX_Api_Sample.cpp +++ b/csharp/test/Microsoft.ML.OnnxRuntime.EndToEndTests.Capi/CXX_Api_Sample.cpp @@ -23,10 +23,11 @@ int main(int argc, char* argv[]) { // Sets graph optimization level // Available levels are - // 0 -> To disable all optimizations - // 1 -> To enable basic optimizations (Such as redundant node removals) - // 2 -> To enable all optimizations (Includes level 1 + more complex optimizations like node fusions) - session_options.SetGraphOptimizationLevel(1); + // ORT_DISABLE_ALL -> To disable all optimizations + // ORT_ENABLE_BASIC -> To enable basic optimizations (Such as redundant node removals) + // ORT_ENABLE_EXTENDED -> To enable extended optimizations (Includes level 1 + more complex optimizations like node fusions) + // ORT_ENABLE_ALL -> To Enable All possible opitmizations + session_options.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_EXTENDED); //************************************************************************* // create session and load model into memory @@ -43,7 +44,7 @@ int main(int argc, char* argv[]) { //************************************************************************* // print model input layer (node names, types, shape etc.) - Ort::Allocator allocator = Ort::Allocator::CreateDefault(); + Ort::AllocatorWithDefaultOptions allocator; // print number of model input nodes size_t num_input_nodes = session.GetInputCount(); diff --git a/csharp/test/Microsoft.ML.OnnxRuntime.EndToEndTests.Capi/C_Api_Sample.cpp b/csharp/test/Microsoft.ML.OnnxRuntime.EndToEndTests.Capi/C_Api_Sample.cpp index 11dae1ab52197..bdc413715281c 100644 --- a/csharp/test/Microsoft.ML.OnnxRuntime.EndToEndTests.Capi/C_Api_Sample.cpp +++ b/csharp/test/Microsoft.ML.OnnxRuntime.EndToEndTests.Capi/C_Api_Sample.cpp @@ -34,11 +34,7 @@ int main(int argc, char* argv[]) { OrtSetSessionThreadPoolSize(session_options, 1); // Sets graph optimization level - // Available levels are - // 0 -> To disable all optimizations - // 1 -> To enable basic optimizations (Such as redundant node removals) - // 2 -> To enable all optimizations (Includes level 1 + more complex optimizations like node fusions) - OrtSetSessionGraphOptimizationLevel(session_options, 1); + OrtSetSessionGraphOptimizationLevel(session_options, ORT_ENABLE_BASIC); // Optionally add more execution providers via session_options // E.g. for CUDA include cuda_provider_factory.h and uncomment the following line: @@ -63,7 +59,7 @@ int main(int argc, char* argv[]) { size_t num_input_nodes; OrtStatus* status; OrtAllocator* allocator; - CHECK_STATUS(OrtCreateDefaultAllocator(&allocator)); + CHECK_STATUS(OrtGetAllocatorWithDefaultOptions(&allocator)); // print number of model input nodes status = OrtSessionGetInputCount(session, &num_input_nodes); @@ -101,7 +97,6 @@ int main(int argc, char* argv[]) { OrtReleaseTypeInfo(typeinfo); } - OrtReleaseAllocator(allocator); // Results should be... 
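The C# binding exposes the same renamed optimization levels through SessionOptions.GraphOptimizationLevel, as exercised by the tests later in this patch; a minimal sketch (model path illustrative):

    using (var options = new SessionOptions())
    {
        // ORT_DISABLE_ALL, ORT_ENABLE_BASIC, ORT_ENABLE_EXTENDED or ORT_ENABLE_ALL
        options.GraphOptimizationLevel = GraphOptimizationLevel.ORT_ENABLE_EXTENDED;
        using (var session = new InferenceSession("squeezenet.onnx", options))
        {
            // run inference as usual
        }
    }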
// Number of inputs = 1 diff --git a/csharp/test/Microsoft.ML.OnnxRuntime.Tests/InferenceTest.cs b/csharp/test/Microsoft.ML.OnnxRuntime.Tests/InferenceTest.cs index 88bf5f83d4c8f..02888e5a52647 100644 --- a/csharp/test/Microsoft.ML.OnnxRuntime.Tests/InferenceTest.cs +++ b/csharp/test/Microsoft.ML.OnnxRuntime.Tests/InferenceTest.cs @@ -6,7 +6,7 @@ using System.Collections.Generic; using System.Linq; using System.Runtime.InteropServices; -using System.Numerics.Tensors; +using Microsoft.ML.OnnxRuntime.Tensors; using System.Threading.Tasks; using Xunit; @@ -17,6 +17,81 @@ public class InferenceTest private const string module = "onnxruntime.dll"; private const string propertiesFile = "Properties.txt"; + [Fact] + public void TestSessionOptions() + { + using (SessionOptions opt = new SessionOptions()) + { + Assert.NotNull(opt); + + // check default values of the properties + Assert.True(opt.EnableSequentialExecution); + Assert.True(opt.EnableMemoryPattern); + Assert.False(opt.EnableProfiling); + Assert.Equal("onnxruntime_profile_", opt.ProfileOutputPathPrefix); + Assert.True(opt.EnableCpuMemArena); + Assert.Equal("", opt.LogId); + Assert.Equal(LogLevel.Verbose, opt.LogVerbosityLevel); + Assert.Equal(0, opt.ThreadPoolSize); + Assert.Equal(GraphOptimizationLevel.ORT_ENABLE_BASIC, opt.GraphOptimizationLevel); + + // try setting options + opt.EnableSequentialExecution = false; + Assert.False(opt.EnableSequentialExecution); + + opt.EnableMemoryPattern = false; + Assert.False(opt.EnableMemoryPattern); + + opt.EnableProfiling = true; + Assert.True(opt.EnableProfiling); + Assert.Equal("onnxruntime_profile_", opt.ProfileOutputPathPrefix); + + opt.ProfileOutputPathPrefix = "Ort_P_"; + Assert.Equal("Ort_P_", opt.ProfileOutputPathPrefix); + + opt.EnableCpuMemArena = false; + Assert.False(opt.EnableCpuMemArena); + + opt.LogId = "MyLogId"; + Assert.Equal("MyLogId", opt.LogId); + + opt.LogVerbosityLevel = LogLevel.Error; + Assert.Equal(LogLevel.Error, opt.LogVerbosityLevel); + + opt.ThreadPoolSize = 4; + Assert.Equal(4, opt.ThreadPoolSize); + + opt.GraphOptimizationLevel = GraphOptimizationLevel.ORT_ENABLE_EXTENDED; + Assert.Equal(GraphOptimizationLevel.ORT_ENABLE_EXTENDED, opt.GraphOptimizationLevel); + + Assert.Throws(() => { opt.GraphOptimizationLevel = (GraphOptimizationLevel)10; }); + } + } + + [Fact] + public void TestRunOptions() + { + using (var opt = new RunOptions()) + { + Assert.NotNull(opt); + + //verify default options + Assert.False(opt.Terminate); + Assert.Equal(LogLevel.Verbose, opt.LogVerbosityLevel); + Assert.Equal("", opt.LogTag); + + // try setting options + opt.Terminate = true; + Assert.True(opt.Terminate); + + opt.LogVerbosityLevel = LogLevel.Error; + Assert.Equal(LogLevel.Error, opt.LogVerbosityLevel); + + opt.LogTag = "MyLogTag"; + Assert.Equal("MyLogTag", opt.LogTag); + } + } + [Fact] public void CanCreateAndDisposeSessionWithModelPath() { @@ -51,18 +126,18 @@ public void CanCreateAndDisposeSessionWithModelPath() } [Theory] - [InlineData(0, true)] - [InlineData(0, false)] - [InlineData(2, true)] - [InlineData(2, false)] - private void CanRunInferenceOnAModel(uint graphOptimizationLevel, bool disableSequentialExecution) + [InlineData(GraphOptimizationLevel.ORT_DISABLE_ALL, true)] + [InlineData(GraphOptimizationLevel.ORT_DISABLE_ALL, false)] + [InlineData(GraphOptimizationLevel.ORT_ENABLE_EXTENDED, true)] + [InlineData(GraphOptimizationLevel.ORT_ENABLE_EXTENDED, false)] + private void CanRunInferenceOnAModel(GraphOptimizationLevel graphOptimizationLevel, bool disableSequentialExecution) 
{ string modelPath = Path.Combine(Directory.GetCurrentDirectory(), "squeezenet.onnx"); // Set the graph optimization level for this session. SessionOptions options = new SessionOptions(); - options.SetSessionGraphOptimizationLevel(graphOptimizationLevel); - if (disableSequentialExecution) options.DisableSequentialExecution(); + options.GraphOptimizationLevel = graphOptimizationLevel; + if (disableSequentialExecution) options.EnableSequentialExecution = false; using (var session = new InferenceSession(modelPath, options)) { @@ -82,32 +157,51 @@ private void CanRunInferenceOnAModel(uint graphOptimizationLevel, bool disableSe // Run the inference using (var results = session.Run(container)) // results is an IReadOnlyList container { - Assert.Equal(1, results.Count); + validateRunResults(results); + } + + // Run Inference with RunOptions + using (var runOptions = new RunOptions()) + { + runOptions.LogTag = "CsharpTest"; + runOptions.Terminate = false; // TODO: Test terminate = true, it currently crashes + runOptions.LogVerbosityLevel = LogLevel.Error; + IReadOnlyCollection outputNames = session.OutputMetadata.Keys.ToList(); - float[] expectedOutput = LoadTensorFromFile(@"bench.expected_out"); - // validate the results - foreach (var r in results) + using (var results = session.Run(container, outputNames, runOptions)) // results is an IReadOnlyList container { - Assert.Equal("softmaxout_1", r.Name); + validateRunResults(results); + } + } + } + } - var resultTensor = r.AsTensor(); - int[] expectedDimensions = { 1, 1000, 1, 1 }; // hardcoded for now for the test data - Assert.Equal(expectedDimensions.Length, resultTensor.Rank); + private void validateRunResults(IDisposableReadOnlyCollection results) + { + float[] expectedOutput = LoadTensorFromFile(@"bench.expected_out"); + // validate the results + foreach (var r in results) + { + Assert.Equal(1, results.Count); + Assert.Equal("softmaxout_1", r.Name); - var resultDimensions = resultTensor.Dimensions; - for (int i = 0; i < expectedDimensions.Length; i++) - { - Assert.Equal(expectedDimensions[i], resultDimensions[i]); - } + var resultTensor = r.AsTensor(); + int[] expectedDimensions = { 1, 1000, 1, 1 }; // hardcoded for now for the test data + Assert.Equal(expectedDimensions.Length, resultTensor.Rank); - var resultArray = r.AsTensor().ToArray(); - Assert.Equal(expectedOutput.Length, resultArray.Length); - Assert.Equal(expectedOutput, resultArray, new floatComparer()); - } + var resultDimensions = resultTensor.Dimensions; + for (int i = 0; i < expectedDimensions.Length; i++) + { + Assert.Equal(expectedDimensions[i], resultDimensions[i]); } + + var resultArray = r.AsTensor().ToArray(); + Assert.Equal(expectedOutput.Length, resultArray.Length); + Assert.Equal(expectedOutput, resultArray, new floatComparer()); } } + [Fact] private void ThrowWrongInputName() { @@ -297,7 +391,7 @@ private void TestModelInputFloat() } } - [Fact(Skip = "Boolean tensor not supported yet")] + [Fact] private void TestModelInputBOOL() { // model takes 1x5 input of fixed type, echoes back @@ -355,15 +449,15 @@ private void TestModelInputDOUBLE() } - [Fact(Skip = "String tensor not supported yet")] + [Fact] private void TestModelInputSTRING() { // model takes 1x5 input of fixed type, echoes back - string modelPath = Path.Combine(Directory.GetCurrentDirectory(), "test_types_STRING.onnx"); + string modelPath = Path.Combine(Directory.GetCurrentDirectory(), "test_types_STRING.pb"); using (var session = new InferenceSession(modelPath)) { var container = new List(); - var tensorIn = 
new DenseTensor(new string[] { "a", "c", "d", "z", "f" }, new int[] { 1, 5 }); + var tensorIn = new DenseTensor(new string[] { "abc", "ced", "def", "", "frozen" }, new int[] { 1, 5 }); var nov = NamedOnnxValue.CreateFromTensor("input", tensorIn); container.Add(nov); using (var res = session.Run(container)) @@ -374,7 +468,7 @@ private void TestModelInputSTRING() } } - [Fact(Skip = "Int8 not supported yet")] + [Fact] private void TestModelInputINT8() { // model takes 1x5 input of fixed type, echoes back @@ -638,6 +732,20 @@ private void TestModelSequenceOfMapStringFloat() } } + [Fact(Skip="The Model Serialization Test fails on linux. Test skipped until fixed. Serialization API should not be used before fix.")] + private void TestModelSerialization() + { + string modelPath = Path.Combine(Directory.GetCurrentDirectory(), "squeezenet.onnx"); + string modelOutputPath = Path.Combine(Directory.GetCurrentDirectory(), "optimized-squeezenet.onnx"); + // Set the optimized model file path to assert that no exception are thrown. + SessionOptions options = new SessionOptions(); + options.OptimizedModelFilePath = modelOutputPath; + options.GraphOptimizationLevel = GraphOptimizationLevel.ORT_ENABLE_BASIC; + var session = new InferenceSession(modelPath, options); + Assert.NotNull(session); + Assert.True(File.Exists(modelOutputPath)); + } + [GpuFact] private void TestGpu() { @@ -658,6 +766,7 @@ private void TestGpu() } } + [DllImport("kernel32", SetLastError = true)] static extern IntPtr LoadLibrary(string lpFileName); @@ -671,15 +780,17 @@ private void VerifyNativeMethodsExist() if (!RuntimeInformation.IsOSPlatform(OSPlatform.Windows)) return; var entryPointNames = new[]{ - "OrtCreateEnv","OrtReleaseEnv","OrtGetErrorCode","OrtGetErrorMessage", - "OrtReleaseStatus","OrtCreateSession","OrtRun","OrtSessionGetInputCount", - "OrtSessionGetOutputCount","OrtSessionGetInputName","OrtSessionGetOutputName","OrtSessionGetInputTypeInfo", - "OrtSessionGetOutputTypeInfo","OrtReleaseSession","OrtCreateSessionOptions","OrtCloneSessionOptions", + "OrtCreateEnv","OrtReleaseEnv", + "OrtGetErrorCode","OrtGetErrorMessage", "OrtReleaseStatus", + "OrtCreateSession","OrtRun", + "OrtSessionGetInputCount", "OrtSessionGetOutputCount","OrtSessionGetInputName","OrtSessionGetOutputName", + "OrtSessionGetInputTypeInfo", "OrtSessionGetOutputTypeInfo","OrtReleaseSession", + "OrtCreateSessionOptions","OrtCloneSessionOptions", "OrtEnableSequentialExecution","OrtDisableSequentialExecution","OrtEnableProfiling","OrtDisableProfiling", "OrtEnableMemPattern","OrtDisableMemPattern","OrtEnableCpuMemArena","OrtDisableCpuMemArena", "OrtSetSessionLogId","OrtSetSessionLogVerbosityLevel","OrtSetSessionThreadPoolSize","OrtSetSessionGraphOptimizationLevel", - "OrtSessionOptionsAppendExecutionProvider_CPU","OrtCreateAllocatorInfo","OrtCreateCpuAllocatorInfo", - "OrtCreateDefaultAllocator","OrtAllocatorFree","OrtAllocatorGetInfo", + "OrtSetOptimizedModelFilePath", "OrtSessionOptionsAppendExecutionProvider_CPU","OrtCreateAllocatorInfo","OrtCreateCpuAllocatorInfo", + "OrtGetAllocatorWithDefaultOptions","OrtAllocatorFree","OrtAllocatorGetInfo", "OrtCreateTensorWithDataAsOrtValue","OrtGetTensorMutableData", "OrtReleaseAllocatorInfo", "OrtCastTypeInfoToTensorInfo","OrtGetTensorTypeAndShape","OrtGetTensorElementType","OrtGetDimensionsCount", "OrtGetDimensions","OrtGetTensorShapeElementCount","OrtReleaseValue"}; diff --git a/csharp/test/Microsoft.ML.OnnxRuntime.Tests/Microsoft.ML.OnnxRuntime.Tests.csproj 
b/csharp/test/Microsoft.ML.OnnxRuntime.Tests/Microsoft.ML.OnnxRuntime.Tests.csproj index 0f307ae7680f2..a271ff96d6b7a 100644 --- a/csharp/test/Microsoft.ML.OnnxRuntime.Tests/Microsoft.ML.OnnxRuntime.Tests.csproj +++ b/csharp/test/Microsoft.ML.OnnxRuntime.Tests/Microsoft.ML.OnnxRuntime.Tests.csproj @@ -10,8 +10,40 @@ $(OnnxRuntimeBuildDirectory)\$(Configuration)\external\protobuf\cmake\$(Configuration) $(OnnxRuntimeCsharpRoot)\..\onnxruntime\core\protobuf $(OnnxRuntimeBuildDirectory)\$(Configuration)\$(Configuration) + + + 7.2 + True + true + false + ..\..\OnnxRuntime.snk + + + + True + True + Tensors\TensorArithmetic.tt + + + True + True + Tensors\TensorOperations.tt + + + + + TextTemplatingFileGenerator + Tensors\TensorArithmetic.cs + + + TextTemplatingFileGenerator + Tensors\TensorOperations.cs + + + + diff --git a/csharp/test/Microsoft.ML.OnnxRuntime.Tests/Tensors/NativeMemory.cs b/csharp/test/Microsoft.ML.OnnxRuntime.Tests/Tensors/NativeMemory.cs new file mode 100644 index 0000000000000..019c3d58cd663 --- /dev/null +++ b/csharp/test/Microsoft.ML.OnnxRuntime.Tests/Tensors/NativeMemory.cs @@ -0,0 +1,119 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// This file is copied and adapted from the following git repository - +// https://github.com/dotnet/corefx +// Commit ID: bdd0814360d4c3a58860919f292a306242f27da1 +// Path: /src/System.Numerics.Tensors/tests/NativeMemory.cs +// Original license statement below - + +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System.Buffers; +using System.Runtime.InteropServices; +using System.Runtime.CompilerServices; +using System.Threading; +using System; + +namespace Microsoft.ML.OnnxRuntime.Tensors.Tests +{ + public class NativeMemory : MemoryManager + { + private bool disposed = false; + private int refCount = 0; + private IntPtr memory; + private int length; + + public NativeMemory(IntPtr memory, int length) + { + this.memory = memory; + this.length = length; + } + + public unsafe NativeMemory(void* memory, int length) + { + this.memory = (IntPtr)memory; + this.length = length; + } + + ~NativeMemory() + { + Dispose(false); + } + + public static NativeMemory Allocate(int length) + { + // typically this would call into a native method appropriate for the platform + // or the constructors above would be used to wrap the native pointer + IntPtr memory = Marshal.AllocHGlobal(Marshal.SizeOf() * length); + return new NativeMemory(memory, length); + } + + public bool IsDisposed => disposed; + + public unsafe override Span GetSpan() => new Span((void*)memory, length); + + protected bool IsRetained => refCount > 0; + + public override MemoryHandle Pin(int elementIndex = 0) + { + unsafe + { + Retain(); + if ((uint)elementIndex > length) throw new ArgumentOutOfRangeException(nameof(elementIndex)); + void* pointer = Unsafe.Add((void*)memory, elementIndex); + return new MemoryHandle(pointer, default, this); + } + } + + public bool Release() + { + int newRefCount = Interlocked.Decrement(ref refCount); + + if (newRefCount < 0) + { + throw new InvalidOperationException("Unmatched Release/Retain"); + } + + return newRefCount != 0; + } + + public void Retain() + { + if (disposed) + { + throw new ObjectDisposedException(nameof(NativeMemory)); + } + + Interlocked.Increment(ref refCount); + } + + protected override void Dispose(bool disposing) + { 
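NativeMemory<T> wraps a native allocation as a MemoryManager<T>, so its Memory property can back a tensor without copying. A minimal usage sketch (the DenseTensor<T> constructor taking a Memory<T> plus dimensions is assumed from the Tensors sources included in this patch):

    var nativeMemory = NativeMemory<float>.Allocate(12);                  // unmanaged buffer for 12 floats
    var tensor = new DenseTensor<float>(nativeMemory.Memory, new[] { 3, 4 });

    tensor[0, 0] = 1.0f;                                                  // writes straight into the native buffer

    ((IDisposable)nativeMemory).Dispose();                                // frees the HGlobal allocation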
+ if (disposed) + { + return; + } + + // typically this would call into a native method appropriate for the platform + Marshal.FreeHGlobal(memory); + memory = IntPtr.Zero; + + disposed = true; + } + + protected override bool TryGetArray(out ArraySegment arraySegment) + { + // cannot expose managed array + arraySegment = default; + return false; + } + + public override void Unpin() + { + Release(); + } + } +} diff --git a/csharp/test/Microsoft.ML.OnnxRuntime.Tests/Tensors/TensorArithmetic.cs b/csharp/test/Microsoft.ML.OnnxRuntime.Tests/Tensors/TensorArithmetic.cs new file mode 100644 index 0000000000000..b1a476b27abe5 --- /dev/null +++ b/csharp/test/Microsoft.ML.OnnxRuntime.Tests/Tensors/TensorArithmetic.cs @@ -0,0 +1,16201 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// This file is copied and adapted from the following git repository - +// https://github.com/dotnet/corefx +// Commit ID: bdd0814360d4c3a58860919f292a306242f27da1 +// Path: /src/System.Numerics.Tensors/tests/TensorArithmetic.cs +// Original license statement below - + +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; + +namespace Microsoft.ML.OnnxRuntime.Tensors +{ + internal interface ITensorArithmetic + { + T One { get; } + T Zero { get; } + void Add(Tensor left, Tensor right, Tensor result); + void Add(Tensor tensor, T scalar, Tensor result); + void And(Tensor left, Tensor right, Tensor result); + void And(Tensor tensor, T scalar, Tensor result); + void Contract(Tensor left, Tensor right, int[] leftAxes, int[] rightAxes, Tensor result); + void Decrement(Tensor tensor, Tensor result); + void Divide(Tensor left, Tensor right, Tensor result); + void Divide(Tensor tensor, T scalar, Tensor result); + void Equals(Tensor left, Tensor right, Tensor result); + void GreaterThan(Tensor left, Tensor right, Tensor result); + void GreaterThanOrEqual(Tensor left, Tensor right, Tensor result); + void Increment(Tensor tensor, Tensor result); + void LeftShift(Tensor tensor, int value, Tensor result); + void LessThan(Tensor left, Tensor right, Tensor result); + void LessThanOrEqual(Tensor left, Tensor right, Tensor result); + void Modulo(Tensor left, Tensor right, Tensor result); + void Modulo(Tensor tensor, T scalar, Tensor result); + void Multiply(Tensor left, Tensor right, Tensor result); + void Multiply(Tensor tensor, T scalar, Tensor result); + void NotEquals(Tensor left, Tensor right, Tensor result); + void Or(Tensor left, Tensor right, Tensor result); + void Or(Tensor tensor, T scalar, Tensor result); + void RightShift(Tensor tensor, int value, Tensor result); + void Subtract(Tensor left, Tensor right, Tensor result); + void Subtract(Tensor tensor, T scalar, Tensor result); + void UnaryMinus(Tensor tensor, Tensor result); + void UnaryPlus(Tensor tensor, Tensor result); + void Xor(Tensor left, Tensor right, Tensor result); + void Xor(Tensor tensor, T scalar, Tensor result); + } + + internal static class TensorArithmetic + { + public static ITensorArithmetic Instance => TensorArithmetic.GetArithmetic(); + } + + internal static class TensorArithmetic + { + public static ITensorArithmetic GetArithmetic() + { + if (typeof(T) == typeof(bool)) + { + return (ITensorArithmetic)new BoolArithmetic(); + } + else if (typeof(T) == typeof(byte)) + { + return (ITensorArithmetic)new ByteArithmetic(); + } + else if 
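The generated arithmetic classes below are resolved per element type through this dispatcher. Conceptually (a sketch only; these types are internal and are normally reached through the TensorOperations helpers generated alongside them, which is an assumption based on the .tt files added to the test project):

    ITensorArithmetic<byte> math = TensorArithmetic<byte>.Instance;       // resolves to ByteArithmetic
    var left   = new DenseTensor<byte>(new byte[] { 1, 2, 3, 4 }, new[] { 2, 2 });
    var right  = new DenseTensor<byte>(new byte[] { 5, 6, 7, 8 }, new[] { 2, 2 });
    var result = new DenseTensor<byte>(new[] { 2, 2 });

    math.Add(left, right, result);                                        // result now holds 6, 8, 10, 12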
(typeof(T) == typeof(char)) + { + return (ITensorArithmetic)new CharArithmetic(); + } + else if (typeof(T) == typeof(decimal)) + { + return (ITensorArithmetic)new DecimalArithmetic(); + } + else if (typeof(T) == typeof(double)) + { + return (ITensorArithmetic)new DoubleArithmetic(); + } + else if (typeof(T) == typeof(float)) + { + return (ITensorArithmetic)new FloatArithmetic(); + } + else if (typeof(T) == typeof(int)) + { + return (ITensorArithmetic)new IntArithmetic(); + } + else if (typeof(T) == typeof(long)) + { + return (ITensorArithmetic)new LongArithmetic(); + } + else if (typeof(T) == typeof(sbyte)) + { + return (ITensorArithmetic)new SByteArithmetic(); + } + else if (typeof(T) == typeof(short)) + { + return (ITensorArithmetic)new ShortArithmetic(); + } + else if (typeof(T) == typeof(uint)) + { + return (ITensorArithmetic)new UIntArithmetic(); + } + else if (typeof(T) == typeof(ulong)) + { + return (ITensorArithmetic)new ULongArithmetic(); + } + else if (typeof(T) == typeof(ushort)) + { + return (ITensorArithmetic)new UShortArithmetic(); + } + return null; + } + } + + internal class BoolArithmetic : ITensorArithmetic + { + public bool One => true; + public bool Zero => false; + + public void Add(Tensor left, Tensor right, Tensor result) + { + throw new NotSupportedException(); + } + public void Add(Tensor tensor, bool scalar, Tensor result) + { + throw new NotSupportedException(); + } + public void And(Tensor left, Tensor right, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = (bool)(left[indices] & right[indices]); + } + + } + public void And(Tensor tensor, bool scalar, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = (bool)(tensor[indices] & scalar); + } + + } + public void Contract(Tensor left, Tensor right, int[] leftAxes, int[] rightAxes, Tensor result) + { + throw new NotSupportedException(); + } + public void Decrement(Tensor tensor, Tensor result) + { + throw new NotSupportedException(); + } + public void Divide(Tensor left, Tensor right, Tensor result) + { + throw new NotSupportedException(); + } + public void Divide(Tensor tensor, bool scalar, Tensor result) + { + throw new NotSupportedException(); + } + public void Equals(Tensor left, Tensor right, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = left[indices] == right[indices]; + } + + } + public void GreaterThan(Tensor left, Tensor right, Tensor result) + { + throw new NotSupportedException(); + } + public void GreaterThanOrEqual(Tensor left, Tensor right, Tensor result) + { + throw new NotSupportedException(); + } + public void Increment(Tensor tensor, Tensor result) + { + throw new NotSupportedException(); + } + public void LeftShift(Tensor tensor, int value, Tensor result) + { + throw new NotSupportedException(); + } + public void LessThan(Tensor left, Tensor right, Tensor result) + { + throw new NotSupportedException(); + } + public void LessThanOrEqual(Tensor left, Tensor right, Tensor result) + { + throw new NotSupportedException(); + } + public void Modulo(Tensor left, Tensor right, Tensor result) + { + throw new 
NotSupportedException(); + } + public void Modulo(Tensor tensor, bool scalar, Tensor result) + { + throw new NotSupportedException(); + } + public void Multiply(Tensor left, Tensor right, Tensor result) + { + throw new NotSupportedException(); + } + public void Multiply(Tensor tensor, bool scalar, Tensor result) + { + throw new NotSupportedException(); + } + public void NotEquals(Tensor left, Tensor right, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = left[indices] != right[indices]; + } + + } + public void Or(Tensor left, Tensor right, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = (bool)(left[indices] | right[indices]); + } + + } + public void Or(Tensor tensor, bool scalar, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = (bool)(tensor[indices] | scalar); + } + + } + public void RightShift(Tensor tensor, int value, Tensor result) + { + throw new NotSupportedException(); + } + public void Subtract(Tensor left, Tensor right, Tensor result) + { + throw new NotSupportedException(); + } + public void Subtract(Tensor tensor, bool scalar, Tensor result) + { + throw new NotSupportedException(); + } + public void UnaryMinus(Tensor tensor, Tensor result) + { + throw new NotSupportedException(); + } + public void UnaryPlus(Tensor tensor, Tensor result) + { + throw new NotSupportedException(); + } + public void Xor(Tensor left, Tensor right, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = (bool)(left[indices] ^ right[indices]); + } + + } + public void Xor(Tensor tensor, bool scalar, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = (bool)(tensor[indices] ^ scalar); + } + + } + + public void Add(DenseTensor left, DenseTensor right, DenseTensor result) + { + throw new NotSupportedException(); + } + public void Add(DenseTensor tensor, bool scalar, DenseTensor result) + { + throw new NotSupportedException(); + } + public void And(DenseTensor left, DenseTensor right, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var leftSpan = left.Buffer.Span; + var rightSpan = right.Buffer.Span; + if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = (bool)(leftSpan[i] & rightSpan[i]); + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + !left.IsReversedStride ? 
left.strides : + right.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + left.IsReversedStride ? left.strides : + right.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = (bool)(leftSpan[op1Index] & rightSpan[op2Index]); + + } + } + } + public void And(DenseTensor tensor, bool scalar, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var tensorSpan = tensor.Buffer.Span; + if (result.IsReversedStride == tensor.IsReversedStride) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = (bool)(tensorSpan[i] & scalar); + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + tensor.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + tensor.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = (bool)(tensorSpan[op1Index] & scalar); + + } + } + } + public void Contract(DenseTensor left, DenseTensor right, int[] leftAxes, int[] rightAxes, DenseTensor result) + { + throw new NotSupportedException(); + } + public void Decrement(DenseTensor tensor, DenseTensor result) + { + throw new NotSupportedException(); + } + public void Divide(DenseTensor left, DenseTensor right, DenseTensor result) + { + throw new NotSupportedException(); + } + public void Divide(DenseTensor tensor, bool scalar, DenseTensor result) + { + throw new NotSupportedException(); + } + public void Equals(DenseTensor left, DenseTensor right, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var leftSpan = left.Buffer.Span; + var rightSpan = right.Buffer.Span; + if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = leftSpan[i] == rightSpan[i]; + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + !left.IsReversedStride ? left.strides : + right.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + left.IsReversedStride ? 
left.strides : + right.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = leftSpan[op1Index] == rightSpan[op2Index]; + + } + } + } + public void GreaterThan(DenseTensor left, DenseTensor right, DenseTensor result) + { + throw new NotSupportedException(); + } + public void GreaterThanOrEqual(DenseTensor left, DenseTensor right, DenseTensor result) + { + throw new NotSupportedException(); + } + public void Increment(DenseTensor tensor, DenseTensor result) + { + throw new NotSupportedException(); + } + public void LeftShift(DenseTensor tensor, int value, DenseTensor result) + { + throw new NotSupportedException(); + } + public void LessThan(DenseTensor left, DenseTensor right, DenseTensor result) + { + throw new NotSupportedException(); + } + public void LessThanOrEqual(DenseTensor left, DenseTensor right, DenseTensor result) + { + throw new NotSupportedException(); + } + public void Modulo(DenseTensor left, DenseTensor right, DenseTensor result) + { + throw new NotSupportedException(); + } + public void Modulo(DenseTensor tensor, bool scalar, DenseTensor result) + { + throw new NotSupportedException(); + } + public void Multiply(DenseTensor left, DenseTensor right, DenseTensor result) + { + throw new NotSupportedException(); + } + public void Multiply(DenseTensor tensor, bool scalar, DenseTensor result) + { + throw new NotSupportedException(); + } + public void NotEquals(DenseTensor left, DenseTensor right, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var leftSpan = left.Buffer.Span; + var rightSpan = right.Buffer.Span; + if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = leftSpan[i] != rightSpan[i]; + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + !left.IsReversedStride ? left.strides : + right.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + left.IsReversedStride ? left.strides : + right.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = leftSpan[op1Index] != rightSpan[op2Index]; + + } + } + } + public void Or(DenseTensor left, DenseTensor right, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var leftSpan = left.Buffer.Span; + var rightSpan = right.Buffer.Span; + if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = (bool)(leftSpan[i] | rightSpan[i]); + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + ref int op2Index = ref right.IsReversedStride ? 
ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + !left.IsReversedStride ? left.strides : + right.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + left.IsReversedStride ? left.strides : + right.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = (bool)(leftSpan[op1Index] | rightSpan[op2Index]); + + } + } + } + public void Or(DenseTensor tensor, bool scalar, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var tensorSpan = tensor.Buffer.Span; + if (result.IsReversedStride == tensor.IsReversedStride) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = (bool)(tensorSpan[i] | scalar); + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + tensor.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + tensor.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = (bool)(tensorSpan[op1Index] | scalar); + + } + } + } + public void RightShift(DenseTensor tensor, int value, DenseTensor result) + { + throw new NotSupportedException(); + } + public void Subtract(DenseTensor left, DenseTensor right, DenseTensor result) + { + throw new NotSupportedException(); + } + public void Subtract(DenseTensor tensor, bool scalar, DenseTensor result) + { + throw new NotSupportedException(); + } + public void UnaryMinus(DenseTensor tensor, DenseTensor result) + { + throw new NotSupportedException(); + } + public void UnaryPlus(DenseTensor tensor, DenseTensor result) + { + throw new NotSupportedException(); + } + public void Xor(DenseTensor left, DenseTensor right, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var leftSpan = left.Buffer.Span; + var rightSpan = right.Buffer.Span; + if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = (bool)(leftSpan[i] ^ rightSpan[i]); + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + !left.IsReversedStride ? left.strides : + right.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + left.IsReversedStride ? 
left.strides : + right.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = (bool)(leftSpan[op1Index] ^ rightSpan[op2Index]); + + } + } + } + public void Xor(DenseTensor tensor, bool scalar, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var tensorSpan = tensor.Buffer.Span; + if (result.IsReversedStride == tensor.IsReversedStride) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = (bool)(tensorSpan[i] ^ scalar); + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + tensor.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + tensor.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = (bool)(tensorSpan[op1Index] ^ scalar); + + } + } + } + } + internal class ByteArithmetic : ITensorArithmetic + { + public byte One => 1; + public byte Zero => 0; + + public void Add(Tensor left, Tensor right, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = (byte)(left[indices] + right[indices]); + } + + } + public void Add(Tensor tensor, byte scalar, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = (byte)(tensor[indices] + scalar); + } + + } + public void And(Tensor left, Tensor right, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = (byte)(left[indices] & right[indices]); + } + + } + public void And(Tensor tensor, byte scalar, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = (byte)(tensor[indices] & scalar); + } + + } + public void Contract(Tensor left, Tensor right, int[] leftAxes, int[] rightAxes, Tensor result) + { + var leftIndices = new int[left.Rank]; + var rightIndices = new int[right.Rank]; + var resultIndices = new int[result.Rank]; + + var summingDimensions = new int[leftAxes.Length]; + for(int i = 0; i < leftAxes.Length; i++) + { + summingDimensions[i] = left.dimensions[leftAxes[i]]; + } + + var summingStrides = ArrayUtilities.GetStrides(summingDimensions); + int summingLength = (int)ArrayUtilities.GetProduct(summingDimensions); + + var resultStrides = result.strides; + + // translates from result index to left non-summing dimensions' index portion + // since left non-summing dimensions are given precedence in result, the end is zero-padded + int[] leftNonSummingStrides = new int[result.Rank]; + + // translates from summing index to left summing dimensions' index portion + int[] 
leftSummingStrides = new int[leftAxes.Length]; + ArrayUtilities.SplitStrides(left.strides, leftAxes, leftNonSummingStrides, 0, leftSummingStrides, 0); + + // translates from result index to right non-summing dimensions' index portion + int[] rightNonSummingStrides = new int[result.Rank]; + // right non-summing dimensions appear after left non-summing dimensions. + int rightNonSummingStridesOffset = (left.Rank - leftAxes.Length); + + // translates from summing index to right summing dimensions' index portion + int[] rightSummingStrides = new int[rightAxes.Length]; + ArrayUtilities.SplitStrides(right.strides, rightAxes, rightNonSummingStrides, rightNonSummingStridesOffset, rightSummingStrides, 0); + + for (int resultIndex = 0; resultIndex < result.Length; resultIndex++) + { + byte sum = (byte)0; + + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, resultIndex, resultIndices); + + int leftIndexNonSumming = ArrayUtilities.TransformIndexByStrides(resultIndex, resultStrides, result.IsReversedStride, leftNonSummingStrides); + int rightIndexNonSumming = ArrayUtilities.TransformIndexByStrides(resultIndex, resultStrides, result.IsReversedStride, rightNonSummingStrides); + + for (int summingIndex = 0; summingIndex < summingLength; summingIndex++) + { + int leftIndexSumming = ArrayUtilities.TransformIndexByStrides(summingIndex, summingStrides, false, leftSummingStrides); + int rightIndexSumming = ArrayUtilities.TransformIndexByStrides(summingIndex, summingStrides, false, rightSummingStrides); + + int leftIndex = leftIndexNonSumming + leftIndexSumming; + int rightIndex = rightIndexNonSumming + rightIndexSumming; + + // todo, make this more efficient + ArrayUtilities.GetIndices(left.strides, left.IsReversedStride, leftIndex, leftIndices); + ArrayUtilities.GetIndices(right.strides, right.IsReversedStride, rightIndex, rightIndices); + + sum += (byte)(left[leftIndices] * right[rightIndices]); + } + + result[resultIndices] = sum; + } + } + public void Decrement(Tensor tensor, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices]--; + } + + } + public void Divide(Tensor left, Tensor right, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = (byte)(left[indices] / right[indices]); + } + + } + public void Divide(Tensor tensor, byte scalar, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = (byte)(tensor[indices] / scalar); + } + + } + public void Equals(Tensor left, Tensor right, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = left[indices] == right[indices]; + } + + } + public void GreaterThan(Tensor left, Tensor right, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = left[indices] > right[indices]; + } + + } + public void GreaterThanOrEqual(Tensor left, Tensor right, Tensor 
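Contract above computes a generalized tensor contraction: each result element is the sum, over the paired left/right axes, of the products of the corresponding elements. A worked sketch (contracting axis 1 of a 2x3 tensor with axis 0 of a 3x2 tensor is an ordinary matrix multiply):

    var left   = new DenseTensor<byte>(new byte[] { 1, 2, 3, 4, 5, 6 }, new[] { 2, 3 });
    var right  = new DenseTensor<byte>(new byte[] { 1, 0, 0, 1, 1, 1 }, new[] { 3, 2 });
    var result = new DenseTensor<byte>(new[] { 2, 2 });

    TensorArithmetic<byte>.Instance.Contract(left, right, new[] { 1 }, new[] { 0 }, result);
    // result[0,0] = 1*1 + 2*0 + 3*1 = 4      result[0,1] = 1*0 + 2*1 + 3*1 = 5
    // result[1,0] = 4*1 + 5*0 + 6*1 = 10     result[1,1] = 4*0 + 5*1 + 6*1 = 11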
result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = left[indices] >= right[indices]; + } + + } + public void Increment(Tensor tensor, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices]++; + } + + } + public void LeftShift(Tensor tensor, int value, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = (byte)(tensor[indices] << value); + } + + } + public void LessThan(Tensor left, Tensor right, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = left[indices] < right[indices]; + } + + } + public void LessThanOrEqual(Tensor left, Tensor right, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = left[indices] <= right[indices]; + } + + } + public void Modulo(Tensor left, Tensor right, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = (byte)(left[indices] % right[indices]); + } + + } + public void Modulo(Tensor tensor, byte scalar, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = (byte)(tensor[indices] % scalar); + } + + } + public void Multiply(Tensor left, Tensor right, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = (byte)(left[indices] * right[indices]); + } + + } + public void Multiply(Tensor tensor, byte scalar, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = (byte)(tensor[indices] * scalar); + } + + } + public void NotEquals(Tensor left, Tensor right, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = left[indices] != right[indices]; + } + + } + public void Or(Tensor left, Tensor right, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = (byte)(left[indices] | right[indices]); + } + + } + public void Or(Tensor tensor, byte scalar, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, 
result.IsReversedStride, i, indices); + result[indices] = (byte)(tensor[indices] | scalar); + } + + } + public void RightShift(Tensor tensor, int value, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = (byte)(tensor[indices] >> value); + } + + } + public void Subtract(Tensor left, Tensor right, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = (byte)(left[indices] - right[indices]); + } + + } + public void Subtract(Tensor tensor, byte scalar, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = (byte)(tensor[indices] - scalar); + } + + } + public void UnaryMinus(Tensor tensor, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = (byte)-tensor[indices]; + } + + } + public void UnaryPlus(Tensor tensor, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = (byte)+tensor[indices]; + } + + } + public void Xor(Tensor left, Tensor right, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = (byte)(left[indices] ^ right[indices]); + } + + } + public void Xor(Tensor tensor, byte scalar, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = (byte)(tensor[indices] ^ scalar); + } + + } + + public void Add(DenseTensor left, DenseTensor right, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var leftSpan = left.Buffer.Span; + var rightSpan = right.Buffer.Span; + if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = (byte)(leftSpan[i] + rightSpan[i]); + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + !left.IsReversedStride ? left.strides : + right.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + left.IsReversedStride ? 
left.strides : + right.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = (byte)(leftSpan[op1Index] + rightSpan[op2Index]); + + } + } + } + public void Add(DenseTensor tensor, byte scalar, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var tensorSpan = tensor.Buffer.Span; + if (result.IsReversedStride == tensor.IsReversedStride) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = (byte)(tensorSpan[i] + scalar); + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + tensor.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + tensor.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = (byte)(tensorSpan[op1Index] + scalar); + + } + } + } + public void And(DenseTensor left, DenseTensor right, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var leftSpan = left.Buffer.Span; + var rightSpan = right.Buffer.Span; + if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = (byte)(leftSpan[i] & rightSpan[i]); + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + !left.IsReversedStride ? left.strides : + right.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + left.IsReversedStride ? left.strides : + right.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = (byte)(leftSpan[op1Index] & rightSpan[op2Index]); + + } + } + } + public void And(DenseTensor tensor, byte scalar, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var tensorSpan = tensor.Buffer.Span; + if (result.IsReversedStride == tensor.IsReversedStride) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = (byte)(tensorSpan[i] & scalar); + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + tensor.strides; + var columnMajorStrides = result.IsReversedStride ? 
result.strides : + tensor.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = (byte)(tensorSpan[op1Index] & scalar); + + } + } + } + public void Contract(DenseTensor left, DenseTensor right, int[] leftAxes, int[] rightAxes, DenseTensor result) + { + var summingDimensions = new int[leftAxes.Length]; + for(int i = 0; i < leftAxes.Length; i++) + { + summingDimensions[i] = left.dimensions[leftAxes[i]]; + } + + var summingStrides = ArrayUtilities.GetStrides(summingDimensions); + int summingLength = (int)ArrayUtilities.GetProduct(summingDimensions); + + var resultStrides = result.strides; + + // translates from result index to left non-summing dimensions' index portion + // since left non-summing dimensions are given precedence in result, the end is zero-padded + int[] leftNonSummingStrides = new int[result.Rank]; + + // translates from summing index to left summing dimensions' index portion + int[] leftSummingStrides = new int[leftAxes.Length]; + ArrayUtilities.SplitStrides(left.strides, leftAxes, leftNonSummingStrides, 0, leftSummingStrides, 0); + + // translates from result index to right non-summing dimensions' index portion + int[] rightNonSummingStrides = new int[result.Rank]; + // right non-summing dimensions appear after left non-summing dimensions. + int rightNonSummingStridesOffset = (left.Rank - leftAxes.Length); + + // translates from summing index to right summing dimensions' index portion + int[] rightSummingStrides = new int[rightAxes.Length]; + ArrayUtilities.SplitStrides(right.strides, rightAxes, rightNonSummingStrides, rightNonSummingStridesOffset, rightSummingStrides, 0); + + var resultSpan = result.Buffer.Span; + var leftSpan = left.Buffer.Span; + var rightSpan = right.Buffer.Span; + + for (int resultIndex = 0; resultIndex < resultSpan.Length; resultIndex++) + { + byte sum = (byte)0; + + int leftIndexNonSumming = ArrayUtilities.TransformIndexByStrides(resultIndex, resultStrides, result.IsReversedStride, leftNonSummingStrides); + int rightIndexNonSumming = ArrayUtilities.TransformIndexByStrides(resultIndex, resultStrides, result.IsReversedStride, rightNonSummingStrides); + + for (int summingIndex = 0; summingIndex < summingLength; summingIndex++) + { + int leftIndexSumming = ArrayUtilities.TransformIndexByStrides(summingIndex, summingStrides, false, leftSummingStrides); + int rightIndexSumming = ArrayUtilities.TransformIndexByStrides(summingIndex, summingStrides, false, rightSummingStrides); + + int leftIndex = leftIndexNonSumming + leftIndexSumming; + int rightIndex = rightIndexNonSumming + rightIndexSumming; + + sum += (byte)(leftSpan[leftIndex] * rightSpan[rightIndex]); + } + + resultSpan[resultIndex] = sum; + } + } + public void Decrement(DenseTensor tensor, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var tensorSpan = tensor.Buffer.Span; + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i]--; + } + } + public void Divide(DenseTensor left, DenseTensor right, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var leftSpan = left.Buffer.Span; + var rightSpan = right.Buffer.Span; + if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = (byte)(leftSpan[i] / rightSpan[i]); + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; 
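+                // Mixed-layout path (descriptive comments; an interpretation of the code below): the result
+                // and its operands do not all share the same element order (row-major vs. reversed/column-major
+                // strides), so one flat loop cannot index every buffer directly. The loop that follows walks
+                // the result linearly in row-major order and, on each iteration, converts that counter into the
+                // matching column-major offset via TransformIndexByStrides; the `ref int` locals alias whichever
+                // counter (rowMajorIndex or colMajorIndex) each tensor needs, so the loop body indexes all
+                // buffers without re-testing IsReversedStride on every element.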
+ + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + !left.IsReversedStride ? left.strides : + right.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + left.IsReversedStride ? left.strides : + right.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = (byte)(leftSpan[op1Index] / rightSpan[op2Index]); + + } + } + } + public void Divide(DenseTensor tensor, byte scalar, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var tensorSpan = tensor.Buffer.Span; + if (result.IsReversedStride == tensor.IsReversedStride) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = (byte)(tensorSpan[i] / scalar); + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + tensor.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + tensor.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = (byte)(tensorSpan[op1Index] / scalar); + + } + } + } + public void Equals(DenseTensor left, DenseTensor right, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var leftSpan = left.Buffer.Span; + var rightSpan = right.Buffer.Span; + if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = leftSpan[i] == rightSpan[i]; + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + !left.IsReversedStride ? left.strides : + right.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + left.IsReversedStride ? 
left.strides : + right.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = leftSpan[op1Index] == rightSpan[op2Index]; + + } + } + } + public void GreaterThan(DenseTensor left, DenseTensor right, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var leftSpan = left.Buffer.Span; + var rightSpan = right.Buffer.Span; + if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = leftSpan[i] > rightSpan[i]; + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + !left.IsReversedStride ? left.strides : + right.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + left.IsReversedStride ? left.strides : + right.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = leftSpan[op1Index] > rightSpan[op2Index]; + + } + } + } + public void GreaterThanOrEqual(DenseTensor left, DenseTensor right, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var leftSpan = left.Buffer.Span; + var rightSpan = right.Buffer.Span; + if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = leftSpan[i] >= rightSpan[i]; + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + !left.IsReversedStride ? left.strides : + right.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + left.IsReversedStride ? left.strides : + right.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = leftSpan[op1Index] >= rightSpan[op2Index]; + + } + } + } + public void Increment(DenseTensor tensor, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var tensorSpan = tensor.Buffer.Span; + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i]++; + } + } + public void LeftShift(DenseTensor tensor, int value, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var tensorSpan = tensor.Buffer.Span; + if (result.IsReversedStride == tensor.IsReversedStride) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = (byte)(tensorSpan[i] << value); + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? 
ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + tensor.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + tensor.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = (byte)(tensorSpan[op1Index] << value); + + } + } + } + public void LessThan(DenseTensor left, DenseTensor right, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var leftSpan = left.Buffer.Span; + var rightSpan = right.Buffer.Span; + if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = leftSpan[i] < rightSpan[i]; + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + !left.IsReversedStride ? left.strides : + right.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + left.IsReversedStride ? left.strides : + right.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = leftSpan[op1Index] < rightSpan[op2Index]; + + } + } + } + public void LessThanOrEqual(DenseTensor left, DenseTensor right, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var leftSpan = left.Buffer.Span; + var rightSpan = right.Buffer.Span; + if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = leftSpan[i] <= rightSpan[i]; + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + !left.IsReversedStride ? left.strides : + right.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + left.IsReversedStride ? 
left.strides : + right.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = leftSpan[op1Index] <= rightSpan[op2Index]; + + } + } + } + public void Modulo(DenseTensor left, DenseTensor right, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var leftSpan = left.Buffer.Span; + var rightSpan = right.Buffer.Span; + if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = (byte)(leftSpan[i] % rightSpan[i]); + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + !left.IsReversedStride ? left.strides : + right.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + left.IsReversedStride ? left.strides : + right.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = (byte)(leftSpan[op1Index] % rightSpan[op2Index]); + + } + } + } + public void Modulo(DenseTensor tensor, byte scalar, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var tensorSpan = tensor.Buffer.Span; + if (result.IsReversedStride == tensor.IsReversedStride) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = (byte)(tensorSpan[i] % scalar); + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + tensor.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + tensor.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = (byte)(tensorSpan[op1Index] % scalar); + + } + } + } + public void Multiply(DenseTensor left, DenseTensor right, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var leftSpan = left.Buffer.Span; + var rightSpan = right.Buffer.Span; + if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = (byte)(leftSpan[i] * rightSpan[i]); + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + !left.IsReversedStride ? left.strides : + right.strides; + var columnMajorStrides = result.IsReversedStride ? 
result.strides : + left.IsReversedStride ? left.strides : + right.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = (byte)(leftSpan[op1Index] * rightSpan[op2Index]); + + } + } + } + public void Multiply(DenseTensor tensor, byte scalar, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var tensorSpan = tensor.Buffer.Span; + if (result.IsReversedStride == tensor.IsReversedStride) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = (byte)(tensorSpan[i] * scalar); + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + tensor.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + tensor.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = (byte)(tensorSpan[op1Index] * scalar); + + } + } + } + public void NotEquals(DenseTensor left, DenseTensor right, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var leftSpan = left.Buffer.Span; + var rightSpan = right.Buffer.Span; + if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = leftSpan[i] != rightSpan[i]; + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + !left.IsReversedStride ? left.strides : + right.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + left.IsReversedStride ? left.strides : + right.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = leftSpan[op1Index] != rightSpan[op2Index]; + + } + } + } + public void Or(DenseTensor left, DenseTensor right, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var leftSpan = left.Buffer.Span; + var rightSpan = right.Buffer.Span; + if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = (byte)(leftSpan[i] | rightSpan[i]); + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + !left.IsReversedStride ? 
left.strides : + right.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + left.IsReversedStride ? left.strides : + right.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = (byte)(leftSpan[op1Index] | rightSpan[op2Index]); + + } + } + } + public void Or(DenseTensor tensor, byte scalar, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var tensorSpan = tensor.Buffer.Span; + if (result.IsReversedStride == tensor.IsReversedStride) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = (byte)(tensorSpan[i] | scalar); + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + tensor.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + tensor.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = (byte)(tensorSpan[op1Index] | scalar); + + } + } + } + public void RightShift(DenseTensor tensor, int value, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var tensorSpan = tensor.Buffer.Span; + if (result.IsReversedStride == tensor.IsReversedStride) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = (byte)(tensorSpan[i] >> value); + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + tensor.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + tensor.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = (byte)(tensorSpan[op1Index] >> value); + + } + } + } + public void Subtract(DenseTensor left, DenseTensor right, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var leftSpan = left.Buffer.Span; + var rightSpan = right.Buffer.Span; + if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = (byte)(leftSpan[i] - rightSpan[i]); + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + !left.IsReversedStride ? left.strides : + right.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + left.IsReversedStride ? 
left.strides : + right.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = (byte)(leftSpan[op1Index] - rightSpan[op2Index]); + + } + } + } + public void Subtract(DenseTensor tensor, byte scalar, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var tensorSpan = tensor.Buffer.Span; + if (result.IsReversedStride == tensor.IsReversedStride) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = (byte)(tensorSpan[i] - scalar); + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + tensor.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + tensor.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = (byte)(tensorSpan[op1Index] - scalar); + + } + } + } + public void UnaryMinus(DenseTensor tensor, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var tensorSpan = tensor.Buffer.Span; + if (result.IsReversedStride == tensor.IsReversedStride) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = (byte)-tensorSpan[i]; + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + tensor.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + tensor.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = (byte)-tensorSpan[op1Index]; + + } + } + } + public void UnaryPlus(DenseTensor tensor, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var tensorSpan = tensor.Buffer.Span; + if (result.IsReversedStride == tensor.IsReversedStride) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = (byte)+tensorSpan[i]; + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + tensor.strides; + var columnMajorStrides = result.IsReversedStride ? 
result.strides : + tensor.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = (byte)+tensorSpan[op1Index]; + + } + } + } + public void Xor(DenseTensor left, DenseTensor right, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var leftSpan = left.Buffer.Span; + var rightSpan = right.Buffer.Span; + if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = (byte)(leftSpan[i] ^ rightSpan[i]); + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + !left.IsReversedStride ? left.strides : + right.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + left.IsReversedStride ? left.strides : + right.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = (byte)(leftSpan[op1Index] ^ rightSpan[op2Index]); + + } + } + } + public void Xor(DenseTensor tensor, byte scalar, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var tensorSpan = tensor.Buffer.Span; + if (result.IsReversedStride == tensor.IsReversedStride) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = (byte)(tensorSpan[i] ^ scalar); + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + tensor.strides; + var columnMajorStrides = result.IsReversedStride ? 
result.strides : + tensor.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = (byte)(tensorSpan[op1Index] ^ scalar); + + } + } + } + } + internal class CharArithmetic : ITensorArithmetic + { + public char One => (char)1; + public char Zero => (char)0; + + public void Add(Tensor left, Tensor right, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = (char)(left[indices] + right[indices]); + } + + } + public void Add(Tensor tensor, char scalar, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = (char)(tensor[indices] + scalar); + } + + } + public void And(Tensor left, Tensor right, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = (char)(left[indices] & right[indices]); + } + + } + public void And(Tensor tensor, char scalar, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = (char)(tensor[indices] & scalar); + } + + } + public void Contract(Tensor left, Tensor right, int[] leftAxes, int[] rightAxes, Tensor result) + { + var leftIndices = new int[left.Rank]; + var rightIndices = new int[right.Rank]; + var resultIndices = new int[result.Rank]; + + var summingDimensions = new int[leftAxes.Length]; + for(int i = 0; i < leftAxes.Length; i++) + { + summingDimensions[i] = left.dimensions[leftAxes[i]]; + } + + var summingStrides = ArrayUtilities.GetStrides(summingDimensions); + int summingLength = (int)ArrayUtilities.GetProduct(summingDimensions); + + var resultStrides = result.strides; + + // translates from result index to left non-summing dimensions' index portion + // since left non-summing dimensions are given precedence in result, the end is zero-padded + int[] leftNonSummingStrides = new int[result.Rank]; + + // translates from summing index to left summing dimensions' index portion + int[] leftSummingStrides = new int[leftAxes.Length]; + ArrayUtilities.SplitStrides(left.strides, leftAxes, leftNonSummingStrides, 0, leftSummingStrides, 0); + + // translates from result index to right non-summing dimensions' index portion + int[] rightNonSummingStrides = new int[result.Rank]; + // right non-summing dimensions appear after left non-summing dimensions. 
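+            // Illustration only (not part of the generated code): contracting a [M,K] left tensor with a
+            // [K,N] right tensor over leftAxes = {1} and rightAxes = {0} is an ordinary matrix multiply.
+            // In that case the left non-summing strides recover the M-index from a result index, the right
+            // non-summing strides (written at the offset computed below, i.e. after the left's surviving
+            // dimensions) recover the N-index, and the summing strides map the flat summing counter onto the
+            // shared K dimension of each operand.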
+ int rightNonSummingStridesOffset = (left.Rank - leftAxes.Length); + + // translates from summing index to right summing dimensions' index portion + int[] rightSummingStrides = new int[rightAxes.Length]; + ArrayUtilities.SplitStrides(right.strides, rightAxes, rightNonSummingStrides, rightNonSummingStridesOffset, rightSummingStrides, 0); + + for (int resultIndex = 0; resultIndex < result.Length; resultIndex++) + { + char sum = (char)0; + + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, resultIndex, resultIndices); + + int leftIndexNonSumming = ArrayUtilities.TransformIndexByStrides(resultIndex, resultStrides, result.IsReversedStride, leftNonSummingStrides); + int rightIndexNonSumming = ArrayUtilities.TransformIndexByStrides(resultIndex, resultStrides, result.IsReversedStride, rightNonSummingStrides); + + for (int summingIndex = 0; summingIndex < summingLength; summingIndex++) + { + int leftIndexSumming = ArrayUtilities.TransformIndexByStrides(summingIndex, summingStrides, false, leftSummingStrides); + int rightIndexSumming = ArrayUtilities.TransformIndexByStrides(summingIndex, summingStrides, false, rightSummingStrides); + + int leftIndex = leftIndexNonSumming + leftIndexSumming; + int rightIndex = rightIndexNonSumming + rightIndexSumming; + + // todo, make this more efficient + ArrayUtilities.GetIndices(left.strides, left.IsReversedStride, leftIndex, leftIndices); + ArrayUtilities.GetIndices(right.strides, right.IsReversedStride, rightIndex, rightIndices); + + sum += (char)(left[leftIndices] * right[rightIndices]); + } + + result[resultIndices] = sum; + } + } + public void Decrement(Tensor tensor, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices]--; + } + + } + public void Divide(Tensor left, Tensor right, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = (char)(left[indices] / right[indices]); + } + + } + public void Divide(Tensor tensor, char scalar, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = (char)(tensor[indices] / scalar); + } + + } + public void Equals(Tensor left, Tensor right, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = left[indices] == right[indices]; + } + + } + public void GreaterThan(Tensor left, Tensor right, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = left[indices] > right[indices]; + } + + } + public void GreaterThanOrEqual(Tensor left, Tensor right, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = left[indices] >= right[indices]; + } + + } + public void Increment(Tensor tensor, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + 
for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices]++; + } + + } + public void LeftShift(Tensor tensor, int value, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = (char)(tensor[indices] << value); + } + + } + public void LessThan(Tensor left, Tensor right, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = left[indices] < right[indices]; + } + + } + public void LessThanOrEqual(Tensor left, Tensor right, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = left[indices] <= right[indices]; + } + + } + public void Modulo(Tensor left, Tensor right, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = (char)(left[indices] % right[indices]); + } + + } + public void Modulo(Tensor tensor, char scalar, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = (char)(tensor[indices] % scalar); + } + + } + public void Multiply(Tensor left, Tensor right, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = (char)(left[indices] * right[indices]); + } + + } + public void Multiply(Tensor tensor, char scalar, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = (char)(tensor[indices] * scalar); + } + + } + public void NotEquals(Tensor left, Tensor right, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = left[indices] != right[indices]; + } + + } + public void Or(Tensor left, Tensor right, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = (char)(left[indices] | right[indices]); + } + + } + public void Or(Tensor tensor, char scalar, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = (char)(tensor[indices] | scalar); + } + + } + public void RightShift(Tensor tensor, int value, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = 
(char)(tensor[indices] >> value); + } + + } + public void Subtract(Tensor left, Tensor right, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = (char)(left[indices] - right[indices]); + } + + } + public void Subtract(Tensor tensor, char scalar, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = (char)(tensor[indices] - scalar); + } + + } + public void UnaryMinus(Tensor tensor, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = (char)-tensor[indices]; + } + + } + public void UnaryPlus(Tensor tensor, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = (char)+tensor[indices]; + } + + } + public void Xor(Tensor left, Tensor right, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = (char)(left[indices] ^ right[indices]); + } + + } + public void Xor(Tensor tensor, char scalar, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = (char)(tensor[indices] ^ scalar); + } + + } + + public void Add(DenseTensor left, DenseTensor right, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var leftSpan = left.Buffer.Span; + var rightSpan = right.Buffer.Span; + if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = (char)(leftSpan[i] + rightSpan[i]); + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + !left.IsReversedStride ? left.strides : + right.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + left.IsReversedStride ? 
left.strides : + right.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = (char)(leftSpan[op1Index] + rightSpan[op2Index]); + + } + } + } + public void Add(DenseTensor tensor, char scalar, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var tensorSpan = tensor.Buffer.Span; + if (result.IsReversedStride == tensor.IsReversedStride) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = (char)(tensorSpan[i] + scalar); + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + tensor.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + tensor.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = (char)(tensorSpan[op1Index] + scalar); + + } + } + } + public void And(DenseTensor left, DenseTensor right, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var leftSpan = left.Buffer.Span; + var rightSpan = right.Buffer.Span; + if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = (char)(leftSpan[i] & rightSpan[i]); + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + !left.IsReversedStride ? left.strides : + right.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + left.IsReversedStride ? left.strides : + right.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = (char)(leftSpan[op1Index] & rightSpan[op2Index]); + + } + } + } + public void And(DenseTensor tensor, char scalar, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var tensorSpan = tensor.Buffer.Span; + if (result.IsReversedStride == tensor.IsReversedStride) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = (char)(tensorSpan[i] & scalar); + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + tensor.strides; + var columnMajorStrides = result.IsReversedStride ? 
result.strides : + tensor.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = (char)(tensorSpan[op1Index] & scalar); + + } + } + } + public void Contract(DenseTensor left, DenseTensor right, int[] leftAxes, int[] rightAxes, DenseTensor result) + { + var summingDimensions = new int[leftAxes.Length]; + for(int i = 0; i < leftAxes.Length; i++) + { + summingDimensions[i] = left.dimensions[leftAxes[i]]; + } + + var summingStrides = ArrayUtilities.GetStrides(summingDimensions); + int summingLength = (int)ArrayUtilities.GetProduct(summingDimensions); + + var resultStrides = result.strides; + + // translates from result index to left non-summing dimensions' index portion + // since left non-summing dimensions are given precedence in result, the end is zero-padded + int[] leftNonSummingStrides = new int[result.Rank]; + + // translates from summing index to left summing dimensions' index portion + int[] leftSummingStrides = new int[leftAxes.Length]; + ArrayUtilities.SplitStrides(left.strides, leftAxes, leftNonSummingStrides, 0, leftSummingStrides, 0); + + // translates from result index to right non-summing dimensions' index portion + int[] rightNonSummingStrides = new int[result.Rank]; + // right non-summing dimensions appear after left non-summing dimensions. + int rightNonSummingStridesOffset = (left.Rank - leftAxes.Length); + + // translates from summing index to right summing dimensions' index portion + int[] rightSummingStrides = new int[rightAxes.Length]; + ArrayUtilities.SplitStrides(right.strides, rightAxes, rightNonSummingStrides, rightNonSummingStridesOffset, rightSummingStrides, 0); + + var resultSpan = result.Buffer.Span; + var leftSpan = left.Buffer.Span; + var rightSpan = right.Buffer.Span; + + for (int resultIndex = 0; resultIndex < resultSpan.Length; resultIndex++) + { + char sum = (char)0; + + int leftIndexNonSumming = ArrayUtilities.TransformIndexByStrides(resultIndex, resultStrides, result.IsReversedStride, leftNonSummingStrides); + int rightIndexNonSumming = ArrayUtilities.TransformIndexByStrides(resultIndex, resultStrides, result.IsReversedStride, rightNonSummingStrides); + + for (int summingIndex = 0; summingIndex < summingLength; summingIndex++) + { + int leftIndexSumming = ArrayUtilities.TransformIndexByStrides(summingIndex, summingStrides, false, leftSummingStrides); + int rightIndexSumming = ArrayUtilities.TransformIndexByStrides(summingIndex, summingStrides, false, rightSummingStrides); + + int leftIndex = leftIndexNonSumming + leftIndexSumming; + int rightIndex = rightIndexNonSumming + rightIndexSumming; + + sum += (char)(leftSpan[leftIndex] * rightSpan[rightIndex]); + } + + resultSpan[resultIndex] = sum; + } + } + public void Decrement(DenseTensor tensor, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var tensorSpan = tensor.Buffer.Span; + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i]--; + } + } + public void Divide(DenseTensor left, DenseTensor right, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var leftSpan = left.Buffer.Span; + var rightSpan = right.Buffer.Span; + if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = (char)(leftSpan[i] / rightSpan[i]); + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; 
+ + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + !left.IsReversedStride ? left.strides : + right.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + left.IsReversedStride ? left.strides : + right.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = (char)(leftSpan[op1Index] / rightSpan[op2Index]); + + } + } + } + public void Divide(DenseTensor tensor, char scalar, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var tensorSpan = tensor.Buffer.Span; + if (result.IsReversedStride == tensor.IsReversedStride) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = (char)(tensorSpan[i] / scalar); + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + tensor.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + tensor.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = (char)(tensorSpan[op1Index] / scalar); + + } + } + } + public void Equals(DenseTensor left, DenseTensor right, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var leftSpan = left.Buffer.Span; + var rightSpan = right.Buffer.Span; + if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = leftSpan[i] == rightSpan[i]; + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + !left.IsReversedStride ? left.strides : + right.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + left.IsReversedStride ? 
left.strides : + right.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = leftSpan[op1Index] == rightSpan[op2Index]; + + } + } + } + public void GreaterThan(DenseTensor left, DenseTensor right, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var leftSpan = left.Buffer.Span; + var rightSpan = right.Buffer.Span; + if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = leftSpan[i] > rightSpan[i]; + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + !left.IsReversedStride ? left.strides : + right.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + left.IsReversedStride ? left.strides : + right.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = leftSpan[op1Index] > rightSpan[op2Index]; + + } + } + } + public void GreaterThanOrEqual(DenseTensor left, DenseTensor right, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var leftSpan = left.Buffer.Span; + var rightSpan = right.Buffer.Span; + if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = leftSpan[i] >= rightSpan[i]; + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + !left.IsReversedStride ? left.strides : + right.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + left.IsReversedStride ? left.strides : + right.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = leftSpan[op1Index] >= rightSpan[op2Index]; + + } + } + } + public void Increment(DenseTensor tensor, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var tensorSpan = tensor.Buffer.Span; + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i]++; + } + } + public void LeftShift(DenseTensor tensor, int value, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var tensorSpan = tensor.Buffer.Span; + if (result.IsReversedStride == tensor.IsReversedStride) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = (char)(tensorSpan[i] << value); + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? 
ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + tensor.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + tensor.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = (char)(tensorSpan[op1Index] << value); + + } + } + } + public void LessThan(DenseTensor left, DenseTensor right, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var leftSpan = left.Buffer.Span; + var rightSpan = right.Buffer.Span; + if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = leftSpan[i] < rightSpan[i]; + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + !left.IsReversedStride ? left.strides : + right.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + left.IsReversedStride ? left.strides : + right.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = leftSpan[op1Index] < rightSpan[op2Index]; + + } + } + } + public void LessThanOrEqual(DenseTensor left, DenseTensor right, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var leftSpan = left.Buffer.Span; + var rightSpan = right.Buffer.Span; + if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = leftSpan[i] <= rightSpan[i]; + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + !left.IsReversedStride ? left.strides : + right.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + left.IsReversedStride ? 
left.strides : + right.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = leftSpan[op1Index] <= rightSpan[op2Index]; + + } + } + } + public void Modulo(DenseTensor left, DenseTensor right, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var leftSpan = left.Buffer.Span; + var rightSpan = right.Buffer.Span; + if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = (char)(leftSpan[i] % rightSpan[i]); + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + !left.IsReversedStride ? left.strides : + right.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + left.IsReversedStride ? left.strides : + right.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = (char)(leftSpan[op1Index] % rightSpan[op2Index]); + + } + } + } + public void Modulo(DenseTensor tensor, char scalar, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var tensorSpan = tensor.Buffer.Span; + if (result.IsReversedStride == tensor.IsReversedStride) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = (char)(tensorSpan[i] % scalar); + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + tensor.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + tensor.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = (char)(tensorSpan[op1Index] % scalar); + + } + } + } + public void Multiply(DenseTensor left, DenseTensor right, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var leftSpan = left.Buffer.Span; + var rightSpan = right.Buffer.Span; + if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = (char)(leftSpan[i] * rightSpan[i]); + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + !left.IsReversedStride ? left.strides : + right.strides; + var columnMajorStrides = result.IsReversedStride ? 
result.strides : + left.IsReversedStride ? left.strides : + right.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = (char)(leftSpan[op1Index] * rightSpan[op2Index]); + + } + } + } + public void Multiply(DenseTensor tensor, char scalar, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var tensorSpan = tensor.Buffer.Span; + if (result.IsReversedStride == tensor.IsReversedStride) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = (char)(tensorSpan[i] * scalar); + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + tensor.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + tensor.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = (char)(tensorSpan[op1Index] * scalar); + + } + } + } + public void NotEquals(DenseTensor left, DenseTensor right, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var leftSpan = left.Buffer.Span; + var rightSpan = right.Buffer.Span; + if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = leftSpan[i] != rightSpan[i]; + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + !left.IsReversedStride ? left.strides : + right.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + left.IsReversedStride ? left.strides : + right.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = leftSpan[op1Index] != rightSpan[op2Index]; + + } + } + } + public void Or(DenseTensor left, DenseTensor right, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var leftSpan = left.Buffer.Span; + var rightSpan = right.Buffer.Span; + if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = (char)(leftSpan[i] | rightSpan[i]); + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + !left.IsReversedStride ? 
left.strides : + right.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + left.IsReversedStride ? left.strides : + right.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = (char)(leftSpan[op1Index] | rightSpan[op2Index]); + + } + } + } + public void Or(DenseTensor tensor, char scalar, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var tensorSpan = tensor.Buffer.Span; + if (result.IsReversedStride == tensor.IsReversedStride) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = (char)(tensorSpan[i] | scalar); + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + tensor.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + tensor.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = (char)(tensorSpan[op1Index] | scalar); + + } + } + } + public void RightShift(DenseTensor tensor, int value, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var tensorSpan = tensor.Buffer.Span; + if (result.IsReversedStride == tensor.IsReversedStride) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = (char)(tensorSpan[i] >> value); + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + tensor.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + tensor.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = (char)(tensorSpan[op1Index] >> value); + + } + } + } + public void Subtract(DenseTensor left, DenseTensor right, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var leftSpan = left.Buffer.Span; + var rightSpan = right.Buffer.Span; + if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = (char)(leftSpan[i] - rightSpan[i]); + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + !left.IsReversedStride ? left.strides : + right.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + left.IsReversedStride ? 
left.strides : + right.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = (char)(leftSpan[op1Index] - rightSpan[op2Index]); + + } + } + } + public void Subtract(DenseTensor tensor, char scalar, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var tensorSpan = tensor.Buffer.Span; + if (result.IsReversedStride == tensor.IsReversedStride) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = (char)(tensorSpan[i] - scalar); + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + tensor.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + tensor.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = (char)(tensorSpan[op1Index] - scalar); + + } + } + } + public void UnaryMinus(DenseTensor tensor, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var tensorSpan = tensor.Buffer.Span; + if (result.IsReversedStride == tensor.IsReversedStride) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = (char)-tensorSpan[i]; + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + tensor.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + tensor.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = (char)-tensorSpan[op1Index]; + + } + } + } + public void UnaryPlus(DenseTensor tensor, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var tensorSpan = tensor.Buffer.Span; + if (result.IsReversedStride == tensor.IsReversedStride) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = (char)+tensorSpan[i]; + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + tensor.strides; + var columnMajorStrides = result.IsReversedStride ? 
result.strides : + tensor.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = (char)+tensorSpan[op1Index]; + + } + } + } + public void Xor(DenseTensor left, DenseTensor right, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var leftSpan = left.Buffer.Span; + var rightSpan = right.Buffer.Span; + if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = (char)(leftSpan[i] ^ rightSpan[i]); + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + !left.IsReversedStride ? left.strides : + right.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + left.IsReversedStride ? left.strides : + right.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = (char)(leftSpan[op1Index] ^ rightSpan[op2Index]); + + } + } + } + public void Xor(DenseTensor tensor, char scalar, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var tensorSpan = tensor.Buffer.Span; + if (result.IsReversedStride == tensor.IsReversedStride) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = (char)(tensorSpan[i] ^ scalar); + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + tensor.strides; + var columnMajorStrides = result.IsReversedStride ? 
result.strides :
+                                         tensor.strides;
+                for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++)
+                {
+                    colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides);
+
+                    resultSpan[resultIndex] = (char)(tensorSpan[op1Index] ^ scalar);
+
+                }
+            }
+        }
+    }
+    internal class DecimalArithmetic : ITensorArithmetic<decimal>
+    {
+        public decimal One => 1;
+        public decimal Zero => 0;
+
+        public void Add(Tensor<decimal> left, Tensor<decimal> right, Tensor<decimal> result)
+        {
+
+            Span<int> indices = new Span<int>(new int[result.Rank]);
+            for(int i = 0; i < result.Length; i++)
+            {
+                ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices);
+                result[indices] = (decimal)(left[indices] + right[indices]);
+            }
+
+        }
+        public void Add(Tensor<decimal> tensor, decimal scalar, Tensor<decimal> result)
+        {
+
+            Span<int> indices = new Span<int>(new int[result.Rank]);
+            for(int i = 0; i < result.Length; i++)
+            {
+                ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices);
+                result[indices] = (decimal)(tensor[indices] + scalar);
+            }
+
+        }
+        public void And(Tensor<decimal> left, Tensor<decimal> right, Tensor<decimal> result)
+        {
+            throw new NotSupportedException();
+        }
+        public void And(Tensor<decimal> tensor, decimal scalar, Tensor<decimal> result)
+        {
+            throw new NotSupportedException();
+        }
+        public void Contract(Tensor<decimal> left, Tensor<decimal> right, int[] leftAxes, int[] rightAxes, Tensor<decimal> result)
+        {
+            var leftIndices = new int[left.Rank];
+            var rightIndices = new int[right.Rank];
+            var resultIndices = new int[result.Rank];
+
+            var summingDimensions = new int[leftAxes.Length];
+            for(int i = 0; i < leftAxes.Length; i++)
+            {
+                summingDimensions[i] = left.dimensions[leftAxes[i]];
+            }
+
+            var summingStrides = ArrayUtilities.GetStrides(summingDimensions);
+            int summingLength = (int)ArrayUtilities.GetProduct(summingDimensions);
+
+            var resultStrides = result.strides;
+
+            // translates from result index to left non-summing dimensions' index portion
+            // since left non-summing dimensions are given precedence in result, the end is zero-padded
+            int[] leftNonSummingStrides = new int[result.Rank];
+
+            // translates from summing index to left summing dimensions' index portion
+            int[] leftSummingStrides = new int[leftAxes.Length];
+            ArrayUtilities.SplitStrides(left.strides, leftAxes, leftNonSummingStrides, 0, leftSummingStrides, 0);
+
+            // translates from result index to right non-summing dimensions' index portion
+            int[] rightNonSummingStrides = new int[result.Rank];
+            // right non-summing dimensions appear after left non-summing dimensions.
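+            // The result layout is [left non-summing dims..., right non-summing dims...], so the
+            // right tensor's non-summing strides are split into the slots that follow the left
+            // tensor's, i.e. starting at offset left.Rank - leftAxes.Length.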
+ int rightNonSummingStridesOffset = (left.Rank - leftAxes.Length); + + // translates from summing index to right summing dimensions' index portion + int[] rightSummingStrides = new int[rightAxes.Length]; + ArrayUtilities.SplitStrides(right.strides, rightAxes, rightNonSummingStrides, rightNonSummingStridesOffset, rightSummingStrides, 0); + + for (int resultIndex = 0; resultIndex < result.Length; resultIndex++) + { + decimal sum = (decimal)0; + + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, resultIndex, resultIndices); + + int leftIndexNonSumming = ArrayUtilities.TransformIndexByStrides(resultIndex, resultStrides, result.IsReversedStride, leftNonSummingStrides); + int rightIndexNonSumming = ArrayUtilities.TransformIndexByStrides(resultIndex, resultStrides, result.IsReversedStride, rightNonSummingStrides); + + for (int summingIndex = 0; summingIndex < summingLength; summingIndex++) + { + int leftIndexSumming = ArrayUtilities.TransformIndexByStrides(summingIndex, summingStrides, false, leftSummingStrides); + int rightIndexSumming = ArrayUtilities.TransformIndexByStrides(summingIndex, summingStrides, false, rightSummingStrides); + + int leftIndex = leftIndexNonSumming + leftIndexSumming; + int rightIndex = rightIndexNonSumming + rightIndexSumming; + + // todo, make this more efficient + ArrayUtilities.GetIndices(left.strides, left.IsReversedStride, leftIndex, leftIndices); + ArrayUtilities.GetIndices(right.strides, right.IsReversedStride, rightIndex, rightIndices); + + sum += (decimal)(left[leftIndices] * right[rightIndices]); + } + + result[resultIndices] = sum; + } + } + public void Decrement(Tensor tensor, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices]--; + } + + } + public void Divide(Tensor left, Tensor right, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = (decimal)(left[indices] / right[indices]); + } + + } + public void Divide(Tensor tensor, decimal scalar, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = (decimal)(tensor[indices] / scalar); + } + + } + public void Equals(Tensor left, Tensor right, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = left[indices] == right[indices]; + } + + } + public void GreaterThan(Tensor left, Tensor right, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = left[indices] > right[indices]; + } + + } + public void GreaterThanOrEqual(Tensor left, Tensor right, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = left[indices] >= right[indices]; + } + + } + public void Increment(Tensor tensor, Tensor result) + { + + Span indices = new Span(new 
int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices]++; + } + + } + public void LeftShift(Tensor tensor, int value, Tensor result) + { + throw new NotSupportedException(); + } + public void LessThan(Tensor left, Tensor right, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = left[indices] < right[indices]; + } + + } + public void LessThanOrEqual(Tensor left, Tensor right, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = left[indices] <= right[indices]; + } + + } + public void Modulo(Tensor left, Tensor right, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = (decimal)(left[indices] % right[indices]); + } + + } + public void Modulo(Tensor tensor, decimal scalar, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = (decimal)(tensor[indices] % scalar); + } + + } + public void Multiply(Tensor left, Tensor right, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = (decimal)(left[indices] * right[indices]); + } + + } + public void Multiply(Tensor tensor, decimal scalar, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = (decimal)(tensor[indices] * scalar); + } + + } + public void NotEquals(Tensor left, Tensor right, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = left[indices] != right[indices]; + } + + } + public void Or(Tensor left, Tensor right, Tensor result) + { + throw new NotSupportedException(); + } + public void Or(Tensor tensor, decimal scalar, Tensor result) + { + throw new NotSupportedException(); + } + public void RightShift(Tensor tensor, int value, Tensor result) + { + throw new NotSupportedException(); + } + public void Subtract(Tensor left, Tensor right, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = (decimal)(left[indices] - right[indices]); + } + + } + public void Subtract(Tensor tensor, decimal scalar, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = (decimal)(tensor[indices] - scalar); + } + + } + public void UnaryMinus(Tensor tensor, Tensor result) + { + + Span indices = new Span(new 
int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = (decimal)-tensor[indices]; + } + + } + public void UnaryPlus(Tensor tensor, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = (decimal)+tensor[indices]; + } + + } + public void Xor(Tensor left, Tensor right, Tensor result) + { + throw new NotSupportedException(); + } + public void Xor(Tensor tensor, decimal scalar, Tensor result) + { + throw new NotSupportedException(); + } + + public void Add(DenseTensor left, DenseTensor right, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var leftSpan = left.Buffer.Span; + var rightSpan = right.Buffer.Span; + if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = (decimal)(leftSpan[i] + rightSpan[i]); + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + !left.IsReversedStride ? left.strides : + right.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + left.IsReversedStride ? left.strides : + right.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = (decimal)(leftSpan[op1Index] + rightSpan[op2Index]); + + } + } + } + public void Add(DenseTensor tensor, decimal scalar, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var tensorSpan = tensor.Buffer.Span; + if (result.IsReversedStride == tensor.IsReversedStride) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = (decimal)(tensorSpan[i] + scalar); + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + tensor.strides; + var columnMajorStrides = result.IsReversedStride ? 
result.strides : + tensor.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = (decimal)(tensorSpan[op1Index] + scalar); + + } + } + } + public void And(DenseTensor left, DenseTensor right, DenseTensor result) + { + throw new NotSupportedException(); + } + public void And(DenseTensor tensor, decimal scalar, DenseTensor result) + { + throw new NotSupportedException(); + } + public void Contract(DenseTensor left, DenseTensor right, int[] leftAxes, int[] rightAxes, DenseTensor result) + { + var summingDimensions = new int[leftAxes.Length]; + for(int i = 0; i < leftAxes.Length; i++) + { + summingDimensions[i] = left.dimensions[leftAxes[i]]; + } + + var summingStrides = ArrayUtilities.GetStrides(summingDimensions); + int summingLength = (int)ArrayUtilities.GetProduct(summingDimensions); + + var resultStrides = result.strides; + + // translates from result index to left non-summing dimensions' index portion + // since left non-summing dimensions are given precedence in result, the end is zero-padded + int[] leftNonSummingStrides = new int[result.Rank]; + + // translates from summing index to left summing dimensions' index portion + int[] leftSummingStrides = new int[leftAxes.Length]; + ArrayUtilities.SplitStrides(left.strides, leftAxes, leftNonSummingStrides, 0, leftSummingStrides, 0); + + // translates from result index to right non-summing dimensions' index portion + int[] rightNonSummingStrides = new int[result.Rank]; + // right non-summing dimensions appear after left non-summing dimensions. + int rightNonSummingStridesOffset = (left.Rank - leftAxes.Length); + + // translates from summing index to right summing dimensions' index portion + int[] rightSummingStrides = new int[rightAxes.Length]; + ArrayUtilities.SplitStrides(right.strides, rightAxes, rightNonSummingStrides, rightNonSummingStridesOffset, rightSummingStrides, 0); + + var resultSpan = result.Buffer.Span; + var leftSpan = left.Buffer.Span; + var rightSpan = right.Buffer.Span; + + for (int resultIndex = 0; resultIndex < resultSpan.Length; resultIndex++) + { + decimal sum = (decimal)0; + + int leftIndexNonSumming = ArrayUtilities.TransformIndexByStrides(resultIndex, resultStrides, result.IsReversedStride, leftNonSummingStrides); + int rightIndexNonSumming = ArrayUtilities.TransformIndexByStrides(resultIndex, resultStrides, result.IsReversedStride, rightNonSummingStrides); + + for (int summingIndex = 0; summingIndex < summingLength; summingIndex++) + { + int leftIndexSumming = ArrayUtilities.TransformIndexByStrides(summingIndex, summingStrides, false, leftSummingStrides); + int rightIndexSumming = ArrayUtilities.TransformIndexByStrides(summingIndex, summingStrides, false, rightSummingStrides); + + int leftIndex = leftIndexNonSumming + leftIndexSumming; + int rightIndex = rightIndexNonSumming + rightIndexSumming; + + sum += (decimal)(leftSpan[leftIndex] * rightSpan[rightIndex]); + } + + resultSpan[resultIndex] = sum; + } + } + public void Decrement(DenseTensor tensor, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var tensorSpan = tensor.Buffer.Span; + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i]--; + } + } + public void Divide(DenseTensor left, DenseTensor right, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var leftSpan = left.Buffer.Span; + var rightSpan = right.Buffer.Span; + if ((result.IsReversedStride == 
left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = (decimal)(leftSpan[i] / rightSpan[i]); + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + !left.IsReversedStride ? left.strides : + right.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + left.IsReversedStride ? left.strides : + right.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = (decimal)(leftSpan[op1Index] / rightSpan[op2Index]); + + } + } + } + public void Divide(DenseTensor tensor, decimal scalar, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var tensorSpan = tensor.Buffer.Span; + if (result.IsReversedStride == tensor.IsReversedStride) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = (decimal)(tensorSpan[i] / scalar); + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + tensor.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + tensor.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = (decimal)(tensorSpan[op1Index] / scalar); + + } + } + } + public void Equals(DenseTensor left, DenseTensor right, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var leftSpan = left.Buffer.Span; + var rightSpan = right.Buffer.Span; + if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = leftSpan[i] == rightSpan[i]; + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + !left.IsReversedStride ? left.strides : + right.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + left.IsReversedStride ? 
left.strides : + right.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = leftSpan[op1Index] == rightSpan[op2Index]; + + } + } + } + public void GreaterThan(DenseTensor left, DenseTensor right, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var leftSpan = left.Buffer.Span; + var rightSpan = right.Buffer.Span; + if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = leftSpan[i] > rightSpan[i]; + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + !left.IsReversedStride ? left.strides : + right.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + left.IsReversedStride ? left.strides : + right.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = leftSpan[op1Index] > rightSpan[op2Index]; + + } + } + } + public void GreaterThanOrEqual(DenseTensor left, DenseTensor right, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var leftSpan = left.Buffer.Span; + var rightSpan = right.Buffer.Span; + if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = leftSpan[i] >= rightSpan[i]; + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + !left.IsReversedStride ? left.strides : + right.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + left.IsReversedStride ? 
left.strides : + right.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = leftSpan[op1Index] >= rightSpan[op2Index]; + + } + } + } + public void Increment(DenseTensor tensor, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var tensorSpan = tensor.Buffer.Span; + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i]++; + } + } + public void LeftShift(DenseTensor tensor, int value, DenseTensor result) + { + throw new NotSupportedException(); + } + public void LessThan(DenseTensor left, DenseTensor right, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var leftSpan = left.Buffer.Span; + var rightSpan = right.Buffer.Span; + if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = leftSpan[i] < rightSpan[i]; + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + !left.IsReversedStride ? left.strides : + right.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + left.IsReversedStride ? left.strides : + right.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = leftSpan[op1Index] < rightSpan[op2Index]; + + } + } + } + public void LessThanOrEqual(DenseTensor left, DenseTensor right, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var leftSpan = left.Buffer.Span; + var rightSpan = right.Buffer.Span; + if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = leftSpan[i] <= rightSpan[i]; + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + !left.IsReversedStride ? left.strides : + right.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + left.IsReversedStride ? 
left.strides : + right.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = leftSpan[op1Index] <= rightSpan[op2Index]; + + } + } + } + public void Modulo(DenseTensor left, DenseTensor right, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var leftSpan = left.Buffer.Span; + var rightSpan = right.Buffer.Span; + if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = (decimal)(leftSpan[i] % rightSpan[i]); + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + !left.IsReversedStride ? left.strides : + right.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + left.IsReversedStride ? left.strides : + right.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = (decimal)(leftSpan[op1Index] % rightSpan[op2Index]); + + } + } + } + public void Modulo(DenseTensor tensor, decimal scalar, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var tensorSpan = tensor.Buffer.Span; + if (result.IsReversedStride == tensor.IsReversedStride) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = (decimal)(tensorSpan[i] % scalar); + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + tensor.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + tensor.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = (decimal)(tensorSpan[op1Index] % scalar); + + } + } + } + public void Multiply(DenseTensor left, DenseTensor right, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var leftSpan = left.Buffer.Span; + var rightSpan = right.Buffer.Span; + if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = (decimal)(leftSpan[i] * rightSpan[i]); + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + !left.IsReversedStride ? 
left.strides : + right.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + left.IsReversedStride ? left.strides : + right.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = (decimal)(leftSpan[op1Index] * rightSpan[op2Index]); + + } + } + } + public void Multiply(DenseTensor tensor, decimal scalar, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var tensorSpan = tensor.Buffer.Span; + if (result.IsReversedStride == tensor.IsReversedStride) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = (decimal)(tensorSpan[i] * scalar); + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + tensor.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + tensor.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = (decimal)(tensorSpan[op1Index] * scalar); + + } + } + } + public void NotEquals(DenseTensor left, DenseTensor right, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var leftSpan = left.Buffer.Span; + var rightSpan = right.Buffer.Span; + if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = leftSpan[i] != rightSpan[i]; + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + !left.IsReversedStride ? left.strides : + right.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + left.IsReversedStride ? 
left.strides : + right.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = leftSpan[op1Index] != rightSpan[op2Index]; + + } + } + } + public void Or(DenseTensor left, DenseTensor right, DenseTensor result) + { + throw new NotSupportedException(); + } + public void Or(DenseTensor tensor, decimal scalar, DenseTensor result) + { + throw new NotSupportedException(); + } + public void RightShift(DenseTensor tensor, int value, DenseTensor result) + { + throw new NotSupportedException(); + } + public void Subtract(DenseTensor left, DenseTensor right, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var leftSpan = left.Buffer.Span; + var rightSpan = right.Buffer.Span; + if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = (decimal)(leftSpan[i] - rightSpan[i]); + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + !left.IsReversedStride ? left.strides : + right.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + left.IsReversedStride ? left.strides : + right.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = (decimal)(leftSpan[op1Index] - rightSpan[op2Index]); + + } + } + } + public void Subtract(DenseTensor tensor, decimal scalar, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var tensorSpan = tensor.Buffer.Span; + if (result.IsReversedStride == tensor.IsReversedStride) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = (decimal)(tensorSpan[i] - scalar); + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + tensor.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + tensor.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = (decimal)(tensorSpan[op1Index] - scalar); + + } + } + } + public void UnaryMinus(DenseTensor tensor, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var tensorSpan = tensor.Buffer.Span; + if (result.IsReversedStride == tensor.IsReversedStride) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = (decimal)-tensorSpan[i]; + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref tensor.IsReversedStride ? 
ref colMajorIndex : ref rowMajorIndex;
+
+                var rowMajorStrides = !result.IsReversedStride ? result.strides :
+                                      tensor.strides;
+                var columnMajorStrides = result.IsReversedStride ? result.strides :
+                                         tensor.strides;
+                for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++)
+                {
+                    colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides);
+
+                    resultSpan[resultIndex] = (decimal)-tensorSpan[op1Index];
+
+                }
+            }
+        }
+        public void UnaryPlus(DenseTensor<decimal> tensor, DenseTensor<decimal> result)
+        {
+
+            var resultSpan = result.Buffer.Span;
+            var tensorSpan = tensor.Buffer.Span;
+            if (result.IsReversedStride == tensor.IsReversedStride)
+            {
+                for(int i = 0; i < resultSpan.Length; i++)
+                {
+                    resultSpan[i] = (decimal)+tensorSpan[i];
+                }
+            }
+            else
+            {
+                int rowMajorIndex = 0;
+                int colMajorIndex = 0;
+
+                ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex;
+                ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex;
+
+                var rowMajorStrides = !result.IsReversedStride ? result.strides :
+                                      tensor.strides;
+                var columnMajorStrides = result.IsReversedStride ? result.strides :
+                                         tensor.strides;
+                for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++)
+                {
+                    colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides);
+
+                    resultSpan[resultIndex] = (decimal)+tensorSpan[op1Index];
+
+                }
+            }
+        }
+        public void Xor(DenseTensor<decimal> left, DenseTensor<decimal> right, DenseTensor<decimal> result)
+        {
+            throw new NotSupportedException();
+        }
+        public void Xor(DenseTensor<decimal> tensor, decimal scalar, DenseTensor<decimal> result)
+        {
+            throw new NotSupportedException();
+        }
+    }
+    internal class DoubleArithmetic : ITensorArithmetic<double>
+    {
+        public double One => 1.0;
+        public double Zero => 0;
+
+        public void Add(Tensor<double> left, Tensor<double> right, Tensor<double> result)
+        {
+
+            Span<int> indices = new Span<int>(new int[result.Rank]);
+            for(int i = 0; i < result.Length; i++)
+            {
+                ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices);
+                result[indices] = (double)(left[indices] + right[indices]);
+            }
+
+        }
+        public void Add(Tensor<double> tensor, double scalar, Tensor<double> result)
+        {
+
+            Span<int> indices = new Span<int>(new int[result.Rank]);
+            for(int i = 0; i < result.Length; i++)
+            {
+                ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices);
+                result[indices] = (double)(tensor[indices] + scalar);
+            }
+
+        }
+        public void And(Tensor<double> left, Tensor<double> right, Tensor<double> result)
+        {
+            throw new NotSupportedException();
+        }
+        public void And(Tensor<double> tensor, double scalar, Tensor<double> result)
+        {
+            throw new NotSupportedException();
+        }
+        public void Contract(Tensor<double> left, Tensor<double> right, int[] leftAxes, int[] rightAxes, Tensor<double> result)
+        {
+            var leftIndices = new int[left.Rank];
+            var rightIndices = new int[right.Rank];
+            var resultIndices = new int[result.Rank];
+
+            var summingDimensions = new int[leftAxes.Length];
+            for(int i = 0; i < leftAxes.Length; i++)
+            {
+                summingDimensions[i] = left.dimensions[leftAxes[i]];
+            }
+
+            var summingStrides = ArrayUtilities.GetStrides(summingDimensions);
+            int summingLength = (int)ArrayUtilities.GetProduct(summingDimensions);
+
+            var resultStrides = result.strides;
+
+            // translates from result index to left non-summing dimensions' index portion
+            // since left non-summing dimensions are given precedence in result, the end is zero-padded
+            int[] leftNonSummingStrides = new int[result.Rank];
+
+            // translates from summing index to left summing dimensions' index portion
+            int[]
leftSummingStrides = new int[leftAxes.Length]; + ArrayUtilities.SplitStrides(left.strides, leftAxes, leftNonSummingStrides, 0, leftSummingStrides, 0); + + // translates from result index to right non-summing dimensions' index portion + int[] rightNonSummingStrides = new int[result.Rank]; + // right non-summing dimensions appear after left non-summing dimensions. + int rightNonSummingStridesOffset = (left.Rank - leftAxes.Length); + + // translates from summing index to right summing dimensions' index portion + int[] rightSummingStrides = new int[rightAxes.Length]; + ArrayUtilities.SplitStrides(right.strides, rightAxes, rightNonSummingStrides, rightNonSummingStridesOffset, rightSummingStrides, 0); + + for (int resultIndex = 0; resultIndex < result.Length; resultIndex++) + { + double sum = (double)0; + + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, resultIndex, resultIndices); + + int leftIndexNonSumming = ArrayUtilities.TransformIndexByStrides(resultIndex, resultStrides, result.IsReversedStride, leftNonSummingStrides); + int rightIndexNonSumming = ArrayUtilities.TransformIndexByStrides(resultIndex, resultStrides, result.IsReversedStride, rightNonSummingStrides); + + for (int summingIndex = 0; summingIndex < summingLength; summingIndex++) + { + int leftIndexSumming = ArrayUtilities.TransformIndexByStrides(summingIndex, summingStrides, false, leftSummingStrides); + int rightIndexSumming = ArrayUtilities.TransformIndexByStrides(summingIndex, summingStrides, false, rightSummingStrides); + + int leftIndex = leftIndexNonSumming + leftIndexSumming; + int rightIndex = rightIndexNonSumming + rightIndexSumming; + + // todo, make this more efficient + ArrayUtilities.GetIndices(left.strides, left.IsReversedStride, leftIndex, leftIndices); + ArrayUtilities.GetIndices(right.strides, right.IsReversedStride, rightIndex, rightIndices); + + sum += (double)(left[leftIndices] * right[rightIndices]); + } + + result[resultIndices] = sum; + } + } + public void Decrement(Tensor tensor, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices]--; + } + + } + public void Divide(Tensor left, Tensor right, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = (double)(left[indices] / right[indices]); + } + + } + public void Divide(Tensor tensor, double scalar, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = (double)(tensor[indices] / scalar); + } + + } + public void Equals(Tensor left, Tensor right, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = left[indices] == right[indices]; + } + + } + public void GreaterThan(Tensor left, Tensor right, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = left[indices] > right[indices]; + } + + } + public void GreaterThanOrEqual(Tensor left, Tensor 
right, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = left[indices] >= right[indices]; + } + + } + public void Increment(Tensor tensor, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices]++; + } + + } + public void LeftShift(Tensor tensor, int value, Tensor result) + { + throw new NotSupportedException(); + } + public void LessThan(Tensor left, Tensor right, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = left[indices] < right[indices]; + } + + } + public void LessThanOrEqual(Tensor left, Tensor right, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = left[indices] <= right[indices]; + } + + } + public void Modulo(Tensor left, Tensor right, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = (double)(left[indices] % right[indices]); + } + + } + public void Modulo(Tensor tensor, double scalar, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = (double)(tensor[indices] % scalar); + } + + } + public void Multiply(Tensor left, Tensor right, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = (double)(left[indices] * right[indices]); + } + + } + public void Multiply(Tensor tensor, double scalar, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = (double)(tensor[indices] * scalar); + } + + } + public void NotEquals(Tensor left, Tensor right, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = left[indices] != right[indices]; + } + + } + public void Or(Tensor left, Tensor right, Tensor result) + { + throw new NotSupportedException(); + } + public void Or(Tensor tensor, double scalar, Tensor result) + { + throw new NotSupportedException(); + } + public void RightShift(Tensor tensor, int value, Tensor result) + { + throw new NotSupportedException(); + } + public void Subtract(Tensor left, Tensor right, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = (double)(left[indices] - right[indices]); + } + + } + public void Subtract(Tensor tensor, double scalar, Tensor 
result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = (double)(tensor[indices] - scalar); + } + + } + public void UnaryMinus(Tensor tensor, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = (double)-tensor[indices]; + } + + } + public void UnaryPlus(Tensor tensor, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = (double)+tensor[indices]; + } + + } + public void Xor(Tensor left, Tensor right, Tensor result) + { + throw new NotSupportedException(); + } + public void Xor(Tensor tensor, double scalar, Tensor result) + { + throw new NotSupportedException(); + } + + public void Add(DenseTensor left, DenseTensor right, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var leftSpan = left.Buffer.Span; + var rightSpan = right.Buffer.Span; + if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = (double)(leftSpan[i] + rightSpan[i]); + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + !left.IsReversedStride ? left.strides : + right.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + left.IsReversedStride ? left.strides : + right.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = (double)(leftSpan[op1Index] + rightSpan[op2Index]); + + } + } + } + public void Add(DenseTensor tensor, double scalar, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var tensorSpan = tensor.Buffer.Span; + if (result.IsReversedStride == tensor.IsReversedStride) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = (double)(tensorSpan[i] + scalar); + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + tensor.strides; + var columnMajorStrides = result.IsReversedStride ? 
result.strides : + tensor.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = (double)(tensorSpan[op1Index] + scalar); + + } + } + } + public void And(DenseTensor left, DenseTensor right, DenseTensor result) + { + throw new NotSupportedException(); + } + public void And(DenseTensor tensor, double scalar, DenseTensor result) + { + throw new NotSupportedException(); + } + public void Contract(DenseTensor left, DenseTensor right, int[] leftAxes, int[] rightAxes, DenseTensor result) + { + var summingDimensions = new int[leftAxes.Length]; + for(int i = 0; i < leftAxes.Length; i++) + { + summingDimensions[i] = left.dimensions[leftAxes[i]]; + } + + var summingStrides = ArrayUtilities.GetStrides(summingDimensions); + int summingLength = (int)ArrayUtilities.GetProduct(summingDimensions); + + var resultStrides = result.strides; + + // translates from result index to left non-summing dimensions' index portion + // since left non-summing dimensions are given precedence in result, the end is zero-padded + int[] leftNonSummingStrides = new int[result.Rank]; + + // translates from summing index to left summing dimensions' index portion + int[] leftSummingStrides = new int[leftAxes.Length]; + ArrayUtilities.SplitStrides(left.strides, leftAxes, leftNonSummingStrides, 0, leftSummingStrides, 0); + + // translates from result index to right non-summing dimensions' index portion + int[] rightNonSummingStrides = new int[result.Rank]; + // right non-summing dimensions appear after left non-summing dimensions. + int rightNonSummingStridesOffset = (left.Rank - leftAxes.Length); + + // translates from summing index to right summing dimensions' index portion + int[] rightSummingStrides = new int[rightAxes.Length]; + ArrayUtilities.SplitStrides(right.strides, rightAxes, rightNonSummingStrides, rightNonSummingStridesOffset, rightSummingStrides, 0); + + var resultSpan = result.Buffer.Span; + var leftSpan = left.Buffer.Span; + var rightSpan = right.Buffer.Span; + + for (int resultIndex = 0; resultIndex < resultSpan.Length; resultIndex++) + { + double sum = (double)0; + + int leftIndexNonSumming = ArrayUtilities.TransformIndexByStrides(resultIndex, resultStrides, result.IsReversedStride, leftNonSummingStrides); + int rightIndexNonSumming = ArrayUtilities.TransformIndexByStrides(resultIndex, resultStrides, result.IsReversedStride, rightNonSummingStrides); + + for (int summingIndex = 0; summingIndex < summingLength; summingIndex++) + { + int leftIndexSumming = ArrayUtilities.TransformIndexByStrides(summingIndex, summingStrides, false, leftSummingStrides); + int rightIndexSumming = ArrayUtilities.TransformIndexByStrides(summingIndex, summingStrides, false, rightSummingStrides); + + int leftIndex = leftIndexNonSumming + leftIndexSumming; + int rightIndex = rightIndexNonSumming + rightIndexSumming; + + sum += (double)(leftSpan[leftIndex] * rightSpan[rightIndex]); + } + + resultSpan[resultIndex] = sum; + } + } + public void Decrement(DenseTensor tensor, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var tensorSpan = tensor.Buffer.Span; + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i]--; + } + } + public void Divide(DenseTensor left, DenseTensor right, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var leftSpan = left.Buffer.Span; + var rightSpan = right.Buffer.Span; + if ((result.IsReversedStride == 
left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = (double)(leftSpan[i] / rightSpan[i]); + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + !left.IsReversedStride ? left.strides : + right.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + left.IsReversedStride ? left.strides : + right.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = (double)(leftSpan[op1Index] / rightSpan[op2Index]); + + } + } + } + public void Divide(DenseTensor tensor, double scalar, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var tensorSpan = tensor.Buffer.Span; + if (result.IsReversedStride == tensor.IsReversedStride) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = (double)(tensorSpan[i] / scalar); + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + tensor.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + tensor.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = (double)(tensorSpan[op1Index] / scalar); + + } + } + } + public void Equals(DenseTensor left, DenseTensor right, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var leftSpan = left.Buffer.Span; + var rightSpan = right.Buffer.Span; + if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = leftSpan[i] == rightSpan[i]; + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + !left.IsReversedStride ? left.strides : + right.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + left.IsReversedStride ? 
left.strides : + right.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = leftSpan[op1Index] == rightSpan[op2Index]; + + } + } + } + public void GreaterThan(DenseTensor left, DenseTensor right, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var leftSpan = left.Buffer.Span; + var rightSpan = right.Buffer.Span; + if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = leftSpan[i] > rightSpan[i]; + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + !left.IsReversedStride ? left.strides : + right.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + left.IsReversedStride ? left.strides : + right.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = leftSpan[op1Index] > rightSpan[op2Index]; + + } + } + } + public void GreaterThanOrEqual(DenseTensor left, DenseTensor right, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var leftSpan = left.Buffer.Span; + var rightSpan = right.Buffer.Span; + if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = leftSpan[i] >= rightSpan[i]; + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + !left.IsReversedStride ? left.strides : + right.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + left.IsReversedStride ? 
left.strides : + right.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = leftSpan[op1Index] >= rightSpan[op2Index]; + + } + } + } + public void Increment(DenseTensor tensor, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var tensorSpan = tensor.Buffer.Span; + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i]++; + } + } + public void LeftShift(DenseTensor tensor, int value, DenseTensor result) + { + throw new NotSupportedException(); + } + public void LessThan(DenseTensor left, DenseTensor right, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var leftSpan = left.Buffer.Span; + var rightSpan = right.Buffer.Span; + if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = leftSpan[i] < rightSpan[i]; + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + !left.IsReversedStride ? left.strides : + right.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + left.IsReversedStride ? left.strides : + right.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = leftSpan[op1Index] < rightSpan[op2Index]; + + } + } + } + public void LessThanOrEqual(DenseTensor left, DenseTensor right, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var leftSpan = left.Buffer.Span; + var rightSpan = right.Buffer.Span; + if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = leftSpan[i] <= rightSpan[i]; + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + !left.IsReversedStride ? left.strides : + right.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + left.IsReversedStride ? 
left.strides : + right.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = leftSpan[op1Index] <= rightSpan[op2Index]; + + } + } + } + public void Modulo(DenseTensor left, DenseTensor right, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var leftSpan = left.Buffer.Span; + var rightSpan = right.Buffer.Span; + if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = (double)(leftSpan[i] % rightSpan[i]); + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + !left.IsReversedStride ? left.strides : + right.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + left.IsReversedStride ? left.strides : + right.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = (double)(leftSpan[op1Index] % rightSpan[op2Index]); + + } + } + } + public void Modulo(DenseTensor tensor, double scalar, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var tensorSpan = tensor.Buffer.Span; + if (result.IsReversedStride == tensor.IsReversedStride) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = (double)(tensorSpan[i] % scalar); + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + tensor.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + tensor.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = (double)(tensorSpan[op1Index] % scalar); + + } + } + } + public void Multiply(DenseTensor left, DenseTensor right, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var leftSpan = left.Buffer.Span; + var rightSpan = right.Buffer.Span; + if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = (double)(leftSpan[i] * rightSpan[i]); + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + !left.IsReversedStride ? left.strides : + right.strides; + var columnMajorStrides = result.IsReversedStride ? 
result.strides : + left.IsReversedStride ? left.strides : + right.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = (double)(leftSpan[op1Index] * rightSpan[op2Index]); + + } + } + } + public void Multiply(DenseTensor tensor, double scalar, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var tensorSpan = tensor.Buffer.Span; + if (result.IsReversedStride == tensor.IsReversedStride) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = (double)(tensorSpan[i] * scalar); + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + tensor.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + tensor.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = (double)(tensorSpan[op1Index] * scalar); + + } + } + } + public void NotEquals(DenseTensor left, DenseTensor right, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var leftSpan = left.Buffer.Span; + var rightSpan = right.Buffer.Span; + if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = leftSpan[i] != rightSpan[i]; + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + !left.IsReversedStride ? left.strides : + right.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + left.IsReversedStride ? left.strides : + right.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = leftSpan[op1Index] != rightSpan[op2Index]; + + } + } + } + public void Or(DenseTensor left, DenseTensor right, DenseTensor result) + { + throw new NotSupportedException(); + } + public void Or(DenseTensor tensor, double scalar, DenseTensor result) + { + throw new NotSupportedException(); + } + public void RightShift(DenseTensor tensor, int value, DenseTensor result) + { + throw new NotSupportedException(); + } + public void Subtract(DenseTensor left, DenseTensor right, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var leftSpan = left.Buffer.Span; + var rightSpan = right.Buffer.Span; + if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = (double)(leftSpan[i] - rightSpan[i]); + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? 
ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + !left.IsReversedStride ? left.strides : + right.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + left.IsReversedStride ? left.strides : + right.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = (double)(leftSpan[op1Index] - rightSpan[op2Index]); + + } + } + } + public void Subtract(DenseTensor tensor, double scalar, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var tensorSpan = tensor.Buffer.Span; + if (result.IsReversedStride == tensor.IsReversedStride) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = (double)(tensorSpan[i] - scalar); + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + tensor.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + tensor.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = (double)(tensorSpan[op1Index] - scalar); + + } + } + } + public void UnaryMinus(DenseTensor tensor, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var tensorSpan = tensor.Buffer.Span; + if (result.IsReversedStride == tensor.IsReversedStride) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = (double)-tensorSpan[i]; + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + tensor.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + tensor.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = (double)-tensorSpan[op1Index]; + + } + } + } + public void UnaryPlus(DenseTensor tensor, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var tensorSpan = tensor.Buffer.Span; + if (result.IsReversedStride == tensor.IsReversedStride) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = (double)+tensorSpan[i]; + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + tensor.strides; + var columnMajorStrides = result.IsReversedStride ? 
result.strides :
+                                         tensor.strides;
+                for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++)
+                {
+                    colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides);
+
+                    resultSpan[resultIndex] = (double)+tensorSpan[op1Index];
+
+                }
+            }
+        }
+        public void Xor(DenseTensor<double> left, DenseTensor<double> right, DenseTensor<double> result)
+        {
+            throw new NotSupportedException();
+        }
+        public void Xor(DenseTensor<double> tensor, double scalar, DenseTensor<double> result)
+        {
+            throw new NotSupportedException();
+        }
+    }
+    internal class FloatArithmetic : ITensorArithmetic<float>
+    {
+        public float One => 1.0f;
+        public float Zero => 0;
+
+        public void Add(Tensor<float> left, Tensor<float> right, Tensor<float> result)
+        {
+
+            Span<int> indices = new Span<int>(new int[result.Rank]);
+            for(int i = 0; i < result.Length; i++)
+            {
+                ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices);
+                result[indices] = (float)(left[indices] + right[indices]);
+            }
+
+        }
+        public void Add(Tensor<float> tensor, float scalar, Tensor<float> result)
+        {
+
+            Span<int> indices = new Span<int>(new int[result.Rank]);
+            for(int i = 0; i < result.Length; i++)
+            {
+                ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices);
+                result[indices] = (float)(tensor[indices] + scalar);
+            }
+
+        }
+        public void And(Tensor<float> left, Tensor<float> right, Tensor<float> result)
+        {
+            throw new NotSupportedException();
+        }
+        public void And(Tensor<float> tensor, float scalar, Tensor<float> result)
+        {
+            throw new NotSupportedException();
+        }
+        public void Contract(Tensor<float> left, Tensor<float> right, int[] leftAxes, int[] rightAxes, Tensor<float> result)
+        {
+            var leftIndices = new int[left.Rank];
+            var rightIndices = new int[right.Rank];
+            var resultIndices = new int[result.Rank];
+
+            var summingDimensions = new int[leftAxes.Length];
+            for(int i = 0; i < leftAxes.Length; i++)
+            {
+                summingDimensions[i] = left.dimensions[leftAxes[i]];
+            }
+
+            var summingStrides = ArrayUtilities.GetStrides(summingDimensions);
+            int summingLength = (int)ArrayUtilities.GetProduct(summingDimensions);
+
+            var resultStrides = result.strides;
+
+            // translates from result index to left non-summing dimensions' index portion
+            // since left non-summing dimensions are given precedence in result, the end is zero-padded
+            int[] leftNonSummingStrides = new int[result.Rank];
+
+            // translates from summing index to left summing dimensions' index portion
+            int[] leftSummingStrides = new int[leftAxes.Length];
+            ArrayUtilities.SplitStrides(left.strides, leftAxes, leftNonSummingStrides, 0, leftSummingStrides, 0);
+
+            // translates from result index to right non-summing dimensions' index portion
+            int[] rightNonSummingStrides = new int[result.Rank];
+            // right non-summing dimensions appear after left non-summing dimensions.
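+            // For example, contracting a 2x3 'left' with a 3x4 'right' over leftAxes = {1} and rightAxes = {0}
+            // (an ordinary matrix multiply) gives summingDimensions = {3}, summingLength = 3 and a 2x4 result:
+            // the left non-summing dimension {2} becomes result dimension 0 and the right non-summing
+            // dimension {4} becomes result dimension 1, which is why SplitStrides is called with an offset
+            // for the right operand below.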
+            int rightNonSummingStridesOffset = (left.Rank - leftAxes.Length);
+
+            // translates from summing index to right summing dimensions' index portion
+            int[] rightSummingStrides = new int[rightAxes.Length];
+            ArrayUtilities.SplitStrides(right.strides, rightAxes, rightNonSummingStrides, rightNonSummingStridesOffset, rightSummingStrides, 0);
+
+            for (int resultIndex = 0; resultIndex < result.Length; resultIndex++)
+            {
+                float sum = (float)0;
+
+                ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, resultIndex, resultIndices);
+
+                int leftIndexNonSumming = ArrayUtilities.TransformIndexByStrides(resultIndex, resultStrides, result.IsReversedStride, leftNonSummingStrides);
+                int rightIndexNonSumming = ArrayUtilities.TransformIndexByStrides(resultIndex, resultStrides, result.IsReversedStride, rightNonSummingStrides);
+
+                for (int summingIndex = 0; summingIndex < summingLength; summingIndex++)
+                {
+                    int leftIndexSumming = ArrayUtilities.TransformIndexByStrides(summingIndex, summingStrides, false, leftSummingStrides);
+                    int rightIndexSumming = ArrayUtilities.TransformIndexByStrides(summingIndex, summingStrides, false, rightSummingStrides);
+
+                    int leftIndex = leftIndexNonSumming + leftIndexSumming;
+                    int rightIndex = rightIndexNonSumming + rightIndexSumming;
+
+                    // todo, make this more efficient
+                    ArrayUtilities.GetIndices(left.strides, left.IsReversedStride, leftIndex, leftIndices);
+                    ArrayUtilities.GetIndices(right.strides, right.IsReversedStride, rightIndex, rightIndices);
+
+                    sum += (float)(left[leftIndices] * right[rightIndices]);
+                }
+
+                result[resultIndices] = sum;
+            }
+        }
+        public void Decrement(Tensor<float> tensor, Tensor<float> result)
+        {
+
+            Span<int> indices = new Span<int>(new int[result.Rank]);
+            for(int i = 0; i < result.Length; i++)
+            {
+                ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices);
+                result[indices]--;
+            }
+
+        }
+        public void Divide(Tensor<float> left, Tensor<float> right, Tensor<float> result)
+        {
+
+            Span<int> indices = new Span<int>(new int[result.Rank]);
+            for(int i = 0; i < result.Length; i++)
+            {
+                ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices);
+                result[indices] = (float)(left[indices] / right[indices]);
+            }
+
+        }
+        public void Divide(Tensor<float> tensor, float scalar, Tensor<float> result)
+        {
+
+            Span<int> indices = new Span<int>(new int[result.Rank]);
+            for(int i = 0; i < result.Length; i++)
+            {
+                ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices);
+                result[indices] = (float)(tensor[indices] / scalar);
+            }
+
+        }
+        public void Equals(Tensor<float> left, Tensor<float> right, Tensor<bool> result)
+        {
+
+            Span<int> indices = new Span<int>(new int[result.Rank]);
+            for(int i = 0; i < result.Length; i++)
+            {
+                ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices);
+                result[indices] = left[indices] == right[indices];
+            }
+
+        }
+        public void GreaterThan(Tensor<float> left, Tensor<float> right, Tensor<bool> result)
+        {
+
+            Span<int> indices = new Span<int>(new int[result.Rank]);
+            for(int i = 0; i < result.Length; i++)
+            {
+                ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices);
+                result[indices] = left[indices] > right[indices];
+            }
+
+        }
+        public void GreaterThanOrEqual(Tensor<float> left, Tensor<float> right, Tensor<bool> result)
+        {
+
+            Span<int> indices = new Span<int>(new int[result.Rank]);
+            for(int i = 0; i < result.Length; i++)
+            {
+                ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices);
+                result[indices] = left[indices] >= right[indices];
+            }
+
+        }
+        public void Increment(Tensor<float> tensor, Tensor<float> result)
+        {
+
+            Span<int> indices = new Span<int>(new int[result.Rank]);
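+            // 'indices' is a scratch buffer reused on every iteration: GetIndices decomposes the linear
+            // index i into a multi-dimensional index using result.strides (honoring IsReversedStride for
+            // column-major layouts), so the indexer in the loop below updates the element that position i
+            // denotes in iteration order.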
+ for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices]++; + } + + } + public void LeftShift(Tensor tensor, int value, Tensor result) + { + throw new NotSupportedException(); + } + public void LessThan(Tensor left, Tensor right, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = left[indices] < right[indices]; + } + + } + public void LessThanOrEqual(Tensor left, Tensor right, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = left[indices] <= right[indices]; + } + + } + public void Modulo(Tensor left, Tensor right, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = (float)(left[indices] % right[indices]); + } + + } + public void Modulo(Tensor tensor, float scalar, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = (float)(tensor[indices] % scalar); + } + + } + public void Multiply(Tensor left, Tensor right, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = (float)(left[indices] * right[indices]); + } + + } + public void Multiply(Tensor tensor, float scalar, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = (float)(tensor[indices] * scalar); + } + + } + public void NotEquals(Tensor left, Tensor right, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = left[indices] != right[indices]; + } + + } + public void Or(Tensor left, Tensor right, Tensor result) + { + throw new NotSupportedException(); + } + public void Or(Tensor tensor, float scalar, Tensor result) + { + throw new NotSupportedException(); + } + public void RightShift(Tensor tensor, int value, Tensor result) + { + throw new NotSupportedException(); + } + public void Subtract(Tensor left, Tensor right, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = (float)(left[indices] - right[indices]); + } + + } + public void Subtract(Tensor tensor, float scalar, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = (float)(tensor[indices] - scalar); + } + + } + public void UnaryMinus(Tensor tensor, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < 
result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = (float)-tensor[indices]; + } + + } + public void UnaryPlus(Tensor tensor, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = (float)+tensor[indices]; + } + + } + public void Xor(Tensor left, Tensor right, Tensor result) + { + throw new NotSupportedException(); + } + public void Xor(Tensor tensor, float scalar, Tensor result) + { + throw new NotSupportedException(); + } + + public void Add(DenseTensor left, DenseTensor right, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var leftSpan = left.Buffer.Span; + var rightSpan = right.Buffer.Span; + if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = (float)(leftSpan[i] + rightSpan[i]); + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + !left.IsReversedStride ? left.strides : + right.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + left.IsReversedStride ? left.strides : + right.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = (float)(leftSpan[op1Index] + rightSpan[op2Index]); + + } + } + } + public void Add(DenseTensor tensor, float scalar, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var tensorSpan = tensor.Buffer.Span; + if (result.IsReversedStride == tensor.IsReversedStride) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = (float)(tensorSpan[i] + scalar); + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + tensor.strides; + var columnMajorStrides = result.IsReversedStride ? 
result.strides : + tensor.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = (float)(tensorSpan[op1Index] + scalar); + + } + } + } + public void And(DenseTensor left, DenseTensor right, DenseTensor result) + { + throw new NotSupportedException(); + } + public void And(DenseTensor tensor, float scalar, DenseTensor result) + { + throw new NotSupportedException(); + } + public void Contract(DenseTensor left, DenseTensor right, int[] leftAxes, int[] rightAxes, DenseTensor result) + { + var summingDimensions = new int[leftAxes.Length]; + for(int i = 0; i < leftAxes.Length; i++) + { + summingDimensions[i] = left.dimensions[leftAxes[i]]; + } + + var summingStrides = ArrayUtilities.GetStrides(summingDimensions); + int summingLength = (int)ArrayUtilities.GetProduct(summingDimensions); + + var resultStrides = result.strides; + + // translates from result index to left non-summing dimensions' index portion + // since left non-summing dimensions are given precedence in result, the end is zero-padded + int[] leftNonSummingStrides = new int[result.Rank]; + + // translates from summing index to left summing dimensions' index portion + int[] leftSummingStrides = new int[leftAxes.Length]; + ArrayUtilities.SplitStrides(left.strides, leftAxes, leftNonSummingStrides, 0, leftSummingStrides, 0); + + // translates from result index to right non-summing dimensions' index portion + int[] rightNonSummingStrides = new int[result.Rank]; + // right non-summing dimensions appear after left non-summing dimensions. + int rightNonSummingStridesOffset = (left.Rank - leftAxes.Length); + + // translates from summing index to right summing dimensions' index portion + int[] rightSummingStrides = new int[rightAxes.Length]; + ArrayUtilities.SplitStrides(right.strides, rightAxes, rightNonSummingStrides, rightNonSummingStridesOffset, rightSummingStrides, 0); + + var resultSpan = result.Buffer.Span; + var leftSpan = left.Buffer.Span; + var rightSpan = right.Buffer.Span; + + for (int resultIndex = 0; resultIndex < resultSpan.Length; resultIndex++) + { + float sum = (float)0; + + int leftIndexNonSumming = ArrayUtilities.TransformIndexByStrides(resultIndex, resultStrides, result.IsReversedStride, leftNonSummingStrides); + int rightIndexNonSumming = ArrayUtilities.TransformIndexByStrides(resultIndex, resultStrides, result.IsReversedStride, rightNonSummingStrides); + + for (int summingIndex = 0; summingIndex < summingLength; summingIndex++) + { + int leftIndexSumming = ArrayUtilities.TransformIndexByStrides(summingIndex, summingStrides, false, leftSummingStrides); + int rightIndexSumming = ArrayUtilities.TransformIndexByStrides(summingIndex, summingStrides, false, rightSummingStrides); + + int leftIndex = leftIndexNonSumming + leftIndexSumming; + int rightIndex = rightIndexNonSumming + rightIndexSumming; + + sum += (float)(leftSpan[leftIndex] * rightSpan[rightIndex]); + } + + resultSpan[resultIndex] = sum; + } + } + public void Decrement(DenseTensor tensor, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var tensorSpan = tensor.Buffer.Span; + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i]--; + } + } + public void Divide(DenseTensor left, DenseTensor right, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var leftSpan = left.Buffer.Span; + var rightSpan = right.Buffer.Span; + if ((result.IsReversedStride == 
left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = (float)(leftSpan[i] / rightSpan[i]); + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + !left.IsReversedStride ? left.strides : + right.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + left.IsReversedStride ? left.strides : + right.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = (float)(leftSpan[op1Index] / rightSpan[op2Index]); + + } + } + } + public void Divide(DenseTensor tensor, float scalar, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var tensorSpan = tensor.Buffer.Span; + if (result.IsReversedStride == tensor.IsReversedStride) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = (float)(tensorSpan[i] / scalar); + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + tensor.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + tensor.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = (float)(tensorSpan[op1Index] / scalar); + + } + } + } + public void Equals(DenseTensor left, DenseTensor right, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var leftSpan = left.Buffer.Span; + var rightSpan = right.Buffer.Span; + if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = leftSpan[i] == rightSpan[i]; + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + !left.IsReversedStride ? left.strides : + right.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + left.IsReversedStride ? 
left.strides : + right.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = leftSpan[op1Index] == rightSpan[op2Index]; + + } + } + } + public void GreaterThan(DenseTensor left, DenseTensor right, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var leftSpan = left.Buffer.Span; + var rightSpan = right.Buffer.Span; + if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = leftSpan[i] > rightSpan[i]; + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + !left.IsReversedStride ? left.strides : + right.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + left.IsReversedStride ? left.strides : + right.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = leftSpan[op1Index] > rightSpan[op2Index]; + + } + } + } + public void GreaterThanOrEqual(DenseTensor left, DenseTensor right, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var leftSpan = left.Buffer.Span; + var rightSpan = right.Buffer.Span; + if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = leftSpan[i] >= rightSpan[i]; + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + !left.IsReversedStride ? left.strides : + right.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + left.IsReversedStride ? 
left.strides : + right.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = leftSpan[op1Index] >= rightSpan[op2Index]; + + } + } + } + public void Increment(DenseTensor tensor, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var tensorSpan = tensor.Buffer.Span; + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i]++; + } + } + public void LeftShift(DenseTensor tensor, int value, DenseTensor result) + { + throw new NotSupportedException(); + } + public void LessThan(DenseTensor left, DenseTensor right, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var leftSpan = left.Buffer.Span; + var rightSpan = right.Buffer.Span; + if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = leftSpan[i] < rightSpan[i]; + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + !left.IsReversedStride ? left.strides : + right.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + left.IsReversedStride ? left.strides : + right.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = leftSpan[op1Index] < rightSpan[op2Index]; + + } + } + } + public void LessThanOrEqual(DenseTensor left, DenseTensor right, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var leftSpan = left.Buffer.Span; + var rightSpan = right.Buffer.Span; + if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = leftSpan[i] <= rightSpan[i]; + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + !left.IsReversedStride ? left.strides : + right.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + left.IsReversedStride ? 
left.strides : + right.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = leftSpan[op1Index] <= rightSpan[op2Index]; + + } + } + } + public void Modulo(DenseTensor left, DenseTensor right, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var leftSpan = left.Buffer.Span; + var rightSpan = right.Buffer.Span; + if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = (float)(leftSpan[i] % rightSpan[i]); + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + !left.IsReversedStride ? left.strides : + right.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + left.IsReversedStride ? left.strides : + right.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = (float)(leftSpan[op1Index] % rightSpan[op2Index]); + + } + } + } + public void Modulo(DenseTensor tensor, float scalar, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var tensorSpan = tensor.Buffer.Span; + if (result.IsReversedStride == tensor.IsReversedStride) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = (float)(tensorSpan[i] % scalar); + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + tensor.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + tensor.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = (float)(tensorSpan[op1Index] % scalar); + + } + } + } + public void Multiply(DenseTensor left, DenseTensor right, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var leftSpan = left.Buffer.Span; + var rightSpan = right.Buffer.Span; + if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = (float)(leftSpan[i] * rightSpan[i]); + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + !left.IsReversedStride ? left.strides : + right.strides; + var columnMajorStrides = result.IsReversedStride ? 
result.strides : + left.IsReversedStride ? left.strides : + right.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = (float)(leftSpan[op1Index] * rightSpan[op2Index]); + + } + } + } + public void Multiply(DenseTensor tensor, float scalar, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var tensorSpan = tensor.Buffer.Span; + if (result.IsReversedStride == tensor.IsReversedStride) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = (float)(tensorSpan[i] * scalar); + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + tensor.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + tensor.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = (float)(tensorSpan[op1Index] * scalar); + + } + } + } + public void NotEquals(DenseTensor left, DenseTensor right, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var leftSpan = left.Buffer.Span; + var rightSpan = right.Buffer.Span; + if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = leftSpan[i] != rightSpan[i]; + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + !left.IsReversedStride ? left.strides : + right.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + left.IsReversedStride ? left.strides : + right.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = leftSpan[op1Index] != rightSpan[op2Index]; + + } + } + } + public void Or(DenseTensor left, DenseTensor right, DenseTensor result) + { + throw new NotSupportedException(); + } + public void Or(DenseTensor tensor, float scalar, DenseTensor result) + { + throw new NotSupportedException(); + } + public void RightShift(DenseTensor tensor, int value, DenseTensor result) + { + throw new NotSupportedException(); + } + public void Subtract(DenseTensor left, DenseTensor right, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var leftSpan = left.Buffer.Span; + var rightSpan = right.Buffer.Span; + if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = (float)(leftSpan[i] - rightSpan[i]); + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? 
ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + !left.IsReversedStride ? left.strides : + right.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + left.IsReversedStride ? left.strides : + right.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = (float)(leftSpan[op1Index] - rightSpan[op2Index]); + + } + } + } + public void Subtract(DenseTensor tensor, float scalar, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var tensorSpan = tensor.Buffer.Span; + if (result.IsReversedStride == tensor.IsReversedStride) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = (float)(tensorSpan[i] - scalar); + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + tensor.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + tensor.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = (float)(tensorSpan[op1Index] - scalar); + + } + } + } + public void UnaryMinus(DenseTensor tensor, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var tensorSpan = tensor.Buffer.Span; + if (result.IsReversedStride == tensor.IsReversedStride) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = (float)-tensorSpan[i]; + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + tensor.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + tensor.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = (float)-tensorSpan[op1Index]; + + } + } + } + public void UnaryPlus(DenseTensor tensor, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var tensorSpan = tensor.Buffer.Span; + if (result.IsReversedStride == tensor.IsReversedStride) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = (float)+tensorSpan[i]; + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + tensor.strides; + var columnMajorStrides = result.IsReversedStride ? 
result.strides : + tensor.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = (float)+tensorSpan[op1Index]; + + } + } + } + public void Xor(DenseTensor left, DenseTensor right, DenseTensor result) + { + throw new NotSupportedException(); + } + public void Xor(DenseTensor tensor, float scalar, DenseTensor result) + { + throw new NotSupportedException(); + } + } + internal class IntArithmetic : ITensorArithmetic + { + public int One => 1; + public int Zero => 0; + + public void Add(Tensor left, Tensor right, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = (int)(left[indices] + right[indices]); + } + + } + public void Add(Tensor tensor, int scalar, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = (int)(tensor[indices] + scalar); + } + + } + public void And(Tensor left, Tensor right, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = (int)(left[indices] & right[indices]); + } + + } + public void And(Tensor tensor, int scalar, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = (int)(tensor[indices] & scalar); + } + + } + public void Contract(Tensor left, Tensor right, int[] leftAxes, int[] rightAxes, Tensor result) + { + var leftIndices = new int[left.Rank]; + var rightIndices = new int[right.Rank]; + var resultIndices = new int[result.Rank]; + + var summingDimensions = new int[leftAxes.Length]; + for(int i = 0; i < leftAxes.Length; i++) + { + summingDimensions[i] = left.dimensions[leftAxes[i]]; + } + + var summingStrides = ArrayUtilities.GetStrides(summingDimensions); + int summingLength = (int)ArrayUtilities.GetProduct(summingDimensions); + + var resultStrides = result.strides; + + // translates from result index to left non-summing dimensions' index portion + // since left non-summing dimensions are given precedence in result, the end is zero-padded + int[] leftNonSummingStrides = new int[result.Rank]; + + // translates from summing index to left summing dimensions' index portion + int[] leftSummingStrides = new int[leftAxes.Length]; + ArrayUtilities.SplitStrides(left.strides, leftAxes, leftNonSummingStrides, 0, leftSummingStrides, 0); + + // translates from result index to right non-summing dimensions' index portion + int[] rightNonSummingStrides = new int[result.Rank]; + // right non-summing dimensions appear after left non-summing dimensions. 
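The stride bookkeeping in Contract generalizes the ordinary matrix product: each result position is split into a left non-summing part and a right non-summing part, and the contracted axes are accumulated in an inner loop. As a rough standalone illustration of the 2-D special case (plain arrays and a hypothetical helper name, not the library code):

    // Sketch only: the 2-D case of tensor contraction (a matrix product),
    // written with linear indices and row-major strides to mirror the
    // summing / non-summing split used by the generated Contract methods.
    static int[] ContractMatrices(int[] left, int[] right, int m, int k, int n)
    {
        // left is m x k, right is k x n, result is m x n, all row-major.
        var result = new int[m * n];
        for (int i = 0; i < m; i++)             // left non-summing axis
        {
            for (int j = 0; j < n; j++)         // right non-summing axis
            {
                int sum = 0;
                for (int s = 0; s < k; s++)     // summing (contracted) axis
                {
                    sum += left[i * k + s] * right[s * n + j];
                }
                result[i * n + j] = sum;
            }
        }
        return result;
    }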
+ int rightNonSummingStridesOffset = (left.Rank - leftAxes.Length); + + // translates from summing index to right summing dimensions' index portion + int[] rightSummingStrides = new int[rightAxes.Length]; + ArrayUtilities.SplitStrides(right.strides, rightAxes, rightNonSummingStrides, rightNonSummingStridesOffset, rightSummingStrides, 0); + + for (int resultIndex = 0; resultIndex < result.Length; resultIndex++) + { + int sum = (int)0; + + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, resultIndex, resultIndices); + + int leftIndexNonSumming = ArrayUtilities.TransformIndexByStrides(resultIndex, resultStrides, result.IsReversedStride, leftNonSummingStrides); + int rightIndexNonSumming = ArrayUtilities.TransformIndexByStrides(resultIndex, resultStrides, result.IsReversedStride, rightNonSummingStrides); + + for (int summingIndex = 0; summingIndex < summingLength; summingIndex++) + { + int leftIndexSumming = ArrayUtilities.TransformIndexByStrides(summingIndex, summingStrides, false, leftSummingStrides); + int rightIndexSumming = ArrayUtilities.TransformIndexByStrides(summingIndex, summingStrides, false, rightSummingStrides); + + int leftIndex = leftIndexNonSumming + leftIndexSumming; + int rightIndex = rightIndexNonSumming + rightIndexSumming; + + // todo, make this more efficient + ArrayUtilities.GetIndices(left.strides, left.IsReversedStride, leftIndex, leftIndices); + ArrayUtilities.GetIndices(right.strides, right.IsReversedStride, rightIndex, rightIndices); + + sum += (int)(left[leftIndices] * right[rightIndices]); + } + + result[resultIndices] = sum; + } + } + public void Decrement(Tensor tensor, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices]--; + } + + } + public void Divide(Tensor left, Tensor right, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = (int)(left[indices] / right[indices]); + } + + } + public void Divide(Tensor tensor, int scalar, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = (int)(tensor[indices] / scalar); + } + + } + public void Equals(Tensor left, Tensor right, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = left[indices] == right[indices]; + } + + } + public void GreaterThan(Tensor left, Tensor right, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = left[indices] > right[indices]; + } + + } + public void GreaterThanOrEqual(Tensor left, Tensor right, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = left[indices] >= right[indices]; + } + + } + public void Increment(Tensor tensor, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i 
= 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices]++; + } + + } + public void LeftShift(Tensor tensor, int value, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = (int)(tensor[indices] << value); + } + + } + public void LessThan(Tensor left, Tensor right, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = left[indices] < right[indices]; + } + + } + public void LessThanOrEqual(Tensor left, Tensor right, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = left[indices] <= right[indices]; + } + + } + public void Modulo(Tensor left, Tensor right, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = (int)(left[indices] % right[indices]); + } + + } + public void Modulo(Tensor tensor, int scalar, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = (int)(tensor[indices] % scalar); + } + + } + public void Multiply(Tensor left, Tensor right, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = (int)(left[indices] * right[indices]); + } + + } + public void Multiply(Tensor tensor, int scalar, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = (int)(tensor[indices] * scalar); + } + + } + public void NotEquals(Tensor left, Tensor right, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = left[indices] != right[indices]; + } + + } + public void Or(Tensor left, Tensor right, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = (int)(left[indices] | right[indices]); + } + + } + public void Or(Tensor tensor, int scalar, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = (int)(tensor[indices] | scalar); + } + + } + public void RightShift(Tensor tensor, int value, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = (int)(tensor[indices] >> 
value); + } + + } + public void Subtract(Tensor left, Tensor right, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = (int)(left[indices] - right[indices]); + } + + } + public void Subtract(Tensor tensor, int scalar, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = (int)(tensor[indices] - scalar); + } + + } + public void UnaryMinus(Tensor tensor, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = (int)-tensor[indices]; + } + + } + public void UnaryPlus(Tensor tensor, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = (int)+tensor[indices]; + } + + } + public void Xor(Tensor left, Tensor right, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = (int)(left[indices] ^ right[indices]); + } + + } + public void Xor(Tensor tensor, int scalar, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = (int)(tensor[indices] ^ scalar); + } + + } + + public void Add(DenseTensor left, DenseTensor right, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var leftSpan = left.Buffer.Span; + var rightSpan = right.Buffer.Span; + if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = (int)(leftSpan[i] + rightSpan[i]); + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + !left.IsReversedStride ? left.strides : + right.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + left.IsReversedStride ? 
left.strides : + right.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = (int)(leftSpan[op1Index] + rightSpan[op2Index]); + + } + } + } + public void Add(DenseTensor tensor, int scalar, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var tensorSpan = tensor.Buffer.Span; + if (result.IsReversedStride == tensor.IsReversedStride) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = (int)(tensorSpan[i] + scalar); + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + tensor.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + tensor.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = (int)(tensorSpan[op1Index] + scalar); + + } + } + } + public void And(DenseTensor left, DenseTensor right, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var leftSpan = left.Buffer.Span; + var rightSpan = right.Buffer.Span; + if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = (int)(leftSpan[i] & rightSpan[i]); + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + !left.IsReversedStride ? left.strides : + right.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + left.IsReversedStride ? left.strides : + right.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = (int)(leftSpan[op1Index] & rightSpan[op2Index]); + + } + } + } + public void And(DenseTensor tensor, int scalar, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var tensorSpan = tensor.Buffer.Span; + if (result.IsReversedStride == tensor.IsReversedStride) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = (int)(tensorSpan[i] & scalar); + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + tensor.strides; + var columnMajorStrides = result.IsReversedStride ? 
result.strides : + tensor.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = (int)(tensorSpan[op1Index] & scalar); + + } + } + } + public void Contract(DenseTensor left, DenseTensor right, int[] leftAxes, int[] rightAxes, DenseTensor result) + { + var summingDimensions = new int[leftAxes.Length]; + for(int i = 0; i < leftAxes.Length; i++) + { + summingDimensions[i] = left.dimensions[leftAxes[i]]; + } + + var summingStrides = ArrayUtilities.GetStrides(summingDimensions); + int summingLength = (int)ArrayUtilities.GetProduct(summingDimensions); + + var resultStrides = result.strides; + + // translates from result index to left non-summing dimensions' index portion + // since left non-summing dimensions are given precedence in result, the end is zero-padded + int[] leftNonSummingStrides = new int[result.Rank]; + + // translates from summing index to left summing dimensions' index portion + int[] leftSummingStrides = new int[leftAxes.Length]; + ArrayUtilities.SplitStrides(left.strides, leftAxes, leftNonSummingStrides, 0, leftSummingStrides, 0); + + // translates from result index to right non-summing dimensions' index portion + int[] rightNonSummingStrides = new int[result.Rank]; + // right non-summing dimensions appear after left non-summing dimensions. + int rightNonSummingStridesOffset = (left.Rank - leftAxes.Length); + + // translates from summing index to right summing dimensions' index portion + int[] rightSummingStrides = new int[rightAxes.Length]; + ArrayUtilities.SplitStrides(right.strides, rightAxes, rightNonSummingStrides, rightNonSummingStridesOffset, rightSummingStrides, 0); + + var resultSpan = result.Buffer.Span; + var leftSpan = left.Buffer.Span; + var rightSpan = right.Buffer.Span; + + for (int resultIndex = 0; resultIndex < resultSpan.Length; resultIndex++) + { + int sum = (int)0; + + int leftIndexNonSumming = ArrayUtilities.TransformIndexByStrides(resultIndex, resultStrides, result.IsReversedStride, leftNonSummingStrides); + int rightIndexNonSumming = ArrayUtilities.TransformIndexByStrides(resultIndex, resultStrides, result.IsReversedStride, rightNonSummingStrides); + + for (int summingIndex = 0; summingIndex < summingLength; summingIndex++) + { + int leftIndexSumming = ArrayUtilities.TransformIndexByStrides(summingIndex, summingStrides, false, leftSummingStrides); + int rightIndexSumming = ArrayUtilities.TransformIndexByStrides(summingIndex, summingStrides, false, rightSummingStrides); + + int leftIndex = leftIndexNonSumming + leftIndexSumming; + int rightIndex = rightIndexNonSumming + rightIndexSumming; + + sum += (int)(leftSpan[leftIndex] * rightSpan[rightIndex]); + } + + resultSpan[resultIndex] = sum; + } + } + public void Decrement(DenseTensor tensor, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var tensorSpan = tensor.Buffer.Span; + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i]--; + } + } + public void Divide(DenseTensor left, DenseTensor right, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var leftSpan = left.Buffer.Span; + var rightSpan = right.Buffer.Span; + if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = (int)(leftSpan[i] / rightSpan[i]); + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + 
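Every mixed-layout branch in this file uses the same trick: a single row-major counter drives the loop, a column-major index is re-derived from it on each iteration, and ref locals bind each operand's index to whichever of the two counters matches that operand's stride order. A simplified standalone sketch of the idea (the inline index arithmetic stands in for ArrayUtilities.TransformIndexByStrides and is not the library implementation):

    // Sketch: copy a rows x cols buffer stored row-major into one stored
    // column-major, using the same ref-local aliasing pattern as above.
    static void CopyRowMajorToColumnMajor(int[] src, int[] dst, int rows, int cols)
    {
        int rowMajorIndex = 0;
        int colMajorIndex = 0;

        // src is row-major and dst is column-major, so alias each index to
        // the matching counter; with three operands this is the role the
        // "ref int resultIndex / op1Index / op2Index" locals play above.
        ref int srcIndex = ref rowMajorIndex;
        ref int dstIndex = ref colMajorIndex;

        for (; rowMajorIndex < src.Length; rowMajorIndex++)
        {
            int r = rowMajorIndex / cols;       // coordinates of this element
            int c = rowMajorIndex % cols;
            colMajorIndex = c * rows + r;       // same element, column-major

            dst[dstIndex] = src[srcIndex];
        }
    }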
ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + !left.IsReversedStride ? left.strides : + right.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + left.IsReversedStride ? left.strides : + right.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = (int)(leftSpan[op1Index] / rightSpan[op2Index]); + + } + } + } + public void Divide(DenseTensor tensor, int scalar, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var tensorSpan = tensor.Buffer.Span; + if (result.IsReversedStride == tensor.IsReversedStride) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = (int)(tensorSpan[i] / scalar); + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + tensor.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + tensor.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = (int)(tensorSpan[op1Index] / scalar); + + } + } + } + public void Equals(DenseTensor left, DenseTensor right, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var leftSpan = left.Buffer.Span; + var rightSpan = right.Buffer.Span; + if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = leftSpan[i] == rightSpan[i]; + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + !left.IsReversedStride ? left.strides : + right.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + left.IsReversedStride ? 
left.strides : + right.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = leftSpan[op1Index] == rightSpan[op2Index]; + + } + } + } + public void GreaterThan(DenseTensor left, DenseTensor right, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var leftSpan = left.Buffer.Span; + var rightSpan = right.Buffer.Span; + if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = leftSpan[i] > rightSpan[i]; + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + !left.IsReversedStride ? left.strides : + right.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + left.IsReversedStride ? left.strides : + right.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = leftSpan[op1Index] > rightSpan[op2Index]; + + } + } + } + public void GreaterThanOrEqual(DenseTensor left, DenseTensor right, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var leftSpan = left.Buffer.Span; + var rightSpan = right.Buffer.Span; + if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = leftSpan[i] >= rightSpan[i]; + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + !left.IsReversedStride ? left.strides : + right.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + left.IsReversedStride ? left.strides : + right.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = leftSpan[op1Index] >= rightSpan[op2Index]; + + } + } + } + public void Increment(DenseTensor tensor, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var tensorSpan = tensor.Buffer.Span; + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i]++; + } + } + public void LeftShift(DenseTensor tensor, int value, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var tensorSpan = tensor.Buffer.Span; + if (result.IsReversedStride == tensor.IsReversedStride) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = (int)(tensorSpan[i] << value); + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? 
ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + tensor.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + tensor.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = (int)(tensorSpan[op1Index] << value); + + } + } + } + public void LessThan(DenseTensor left, DenseTensor right, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var leftSpan = left.Buffer.Span; + var rightSpan = right.Buffer.Span; + if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = leftSpan[i] < rightSpan[i]; + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + !left.IsReversedStride ? left.strides : + right.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + left.IsReversedStride ? left.strides : + right.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = leftSpan[op1Index] < rightSpan[op2Index]; + + } + } + } + public void LessThanOrEqual(DenseTensor left, DenseTensor right, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var leftSpan = left.Buffer.Span; + var rightSpan = right.Buffer.Span; + if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = leftSpan[i] <= rightSpan[i]; + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + !left.IsReversedStride ? left.strides : + right.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + left.IsReversedStride ? 
left.strides : + right.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = leftSpan[op1Index] <= rightSpan[op2Index]; + + } + } + } + public void Modulo(DenseTensor left, DenseTensor right, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var leftSpan = left.Buffer.Span; + var rightSpan = right.Buffer.Span; + if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = (int)(leftSpan[i] % rightSpan[i]); + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + !left.IsReversedStride ? left.strides : + right.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + left.IsReversedStride ? left.strides : + right.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = (int)(leftSpan[op1Index] % rightSpan[op2Index]); + + } + } + } + public void Modulo(DenseTensor tensor, int scalar, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var tensorSpan = tensor.Buffer.Span; + if (result.IsReversedStride == tensor.IsReversedStride) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = (int)(tensorSpan[i] % scalar); + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + tensor.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + tensor.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = (int)(tensorSpan[op1Index] % scalar); + + } + } + } + public void Multiply(DenseTensor left, DenseTensor right, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var leftSpan = left.Buffer.Span; + var rightSpan = right.Buffer.Span; + if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = (int)(leftSpan[i] * rightSpan[i]); + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + !left.IsReversedStride ? left.strides : + right.strides; + var columnMajorStrides = result.IsReversedStride ? 
result.strides : + left.IsReversedStride ? left.strides : + right.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = (int)(leftSpan[op1Index] * rightSpan[op2Index]); + + } + } + } + public void Multiply(DenseTensor tensor, int scalar, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var tensorSpan = tensor.Buffer.Span; + if (result.IsReversedStride == tensor.IsReversedStride) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = (int)(tensorSpan[i] * scalar); + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + tensor.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + tensor.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = (int)(tensorSpan[op1Index] * scalar); + + } + } + } + public void NotEquals(DenseTensor left, DenseTensor right, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var leftSpan = left.Buffer.Span; + var rightSpan = right.Buffer.Span; + if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = leftSpan[i] != rightSpan[i]; + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + !left.IsReversedStride ? left.strides : + right.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + left.IsReversedStride ? left.strides : + right.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = leftSpan[op1Index] != rightSpan[op2Index]; + + } + } + } + public void Or(DenseTensor left, DenseTensor right, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var leftSpan = left.Buffer.Span; + var rightSpan = right.Buffer.Span; + if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = (int)(leftSpan[i] | rightSpan[i]); + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + !left.IsReversedStride ? 
left.strides : + right.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + left.IsReversedStride ? left.strides : + right.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = (int)(leftSpan[op1Index] | rightSpan[op2Index]); + + } + } + } + public void Or(DenseTensor tensor, int scalar, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var tensorSpan = tensor.Buffer.Span; + if (result.IsReversedStride == tensor.IsReversedStride) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = (int)(tensorSpan[i] | scalar); + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + tensor.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + tensor.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = (int)(tensorSpan[op1Index] | scalar); + + } + } + } + public void RightShift(DenseTensor tensor, int value, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var tensorSpan = tensor.Buffer.Span; + if (result.IsReversedStride == tensor.IsReversedStride) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = (int)(tensorSpan[i] >> value); + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + tensor.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + tensor.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = (int)(tensorSpan[op1Index] >> value); + + } + } + } + public void Subtract(DenseTensor left, DenseTensor right, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var leftSpan = left.Buffer.Span; + var rightSpan = right.Buffer.Span; + if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = (int)(leftSpan[i] - rightSpan[i]); + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + !left.IsReversedStride ? left.strides : + right.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + left.IsReversedStride ? 
left.strides : + right.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = (int)(leftSpan[op1Index] - rightSpan[op2Index]); + + } + } + } + public void Subtract(DenseTensor tensor, int scalar, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var tensorSpan = tensor.Buffer.Span; + if (result.IsReversedStride == tensor.IsReversedStride) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = (int)(tensorSpan[i] - scalar); + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + tensor.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + tensor.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = (int)(tensorSpan[op1Index] - scalar); + + } + } + } + public void UnaryMinus(DenseTensor tensor, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var tensorSpan = tensor.Buffer.Span; + if (result.IsReversedStride == tensor.IsReversedStride) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = (int)-tensorSpan[i]; + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + tensor.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + tensor.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = (int)-tensorSpan[op1Index]; + + } + } + } + public void UnaryPlus(DenseTensor tensor, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var tensorSpan = tensor.Buffer.Span; + if (result.IsReversedStride == tensor.IsReversedStride) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = (int)+tensorSpan[i]; + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + tensor.strides; + var columnMajorStrides = result.IsReversedStride ? 
result.strides : + tensor.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = (int)+tensorSpan[op1Index]; + + } + } + } + public void Xor(DenseTensor left, DenseTensor right, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var leftSpan = left.Buffer.Span; + var rightSpan = right.Buffer.Span; + if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = (int)(leftSpan[i] ^ rightSpan[i]); + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + !left.IsReversedStride ? left.strides : + right.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + left.IsReversedStride ? left.strides : + right.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = (int)(leftSpan[op1Index] ^ rightSpan[op2Index]); + + } + } + } + public void Xor(DenseTensor tensor, int scalar, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var tensorSpan = tensor.Buffer.Span; + if (result.IsReversedStride == tensor.IsReversedStride) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = (int)(tensorSpan[i] ^ scalar); + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + tensor.strides; + var columnMajorStrides = result.IsReversedStride ? 
result.strides : + tensor.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = (int)(tensorSpan[op1Index] ^ scalar); + + } + } + } + } + internal class LongArithmetic : ITensorArithmetic + { + public long One => 1; + public long Zero => 0; + + public void Add(Tensor left, Tensor right, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = (long)(left[indices] + right[indices]); + } + + } + public void Add(Tensor tensor, long scalar, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = (long)(tensor[indices] + scalar); + } + + } + public void And(Tensor left, Tensor right, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = (long)(left[indices] & right[indices]); + } + + } + public void And(Tensor tensor, long scalar, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = (long)(tensor[indices] & scalar); + } + + } + public void Contract(Tensor left, Tensor right, int[] leftAxes, int[] rightAxes, Tensor result) + { + var leftIndices = new int[left.Rank]; + var rightIndices = new int[right.Rank]; + var resultIndices = new int[result.Rank]; + + var summingDimensions = new int[leftAxes.Length]; + for(int i = 0; i < leftAxes.Length; i++) + { + summingDimensions[i] = left.dimensions[leftAxes[i]]; + } + + var summingStrides = ArrayUtilities.GetStrides(summingDimensions); + int summingLength = (int)ArrayUtilities.GetProduct(summingDimensions); + + var resultStrides = result.strides; + + // translates from result index to left non-summing dimensions' index portion + // since left non-summing dimensions are given precedence in result, the end is zero-padded + int[] leftNonSummingStrides = new int[result.Rank]; + + // translates from summing index to left summing dimensions' index portion + int[] leftSummingStrides = new int[leftAxes.Length]; + ArrayUtilities.SplitStrides(left.strides, leftAxes, leftNonSummingStrides, 0, leftSummingStrides, 0); + + // translates from result index to right non-summing dimensions' index portion + int[] rightNonSummingStrides = new int[result.Rank]; + // right non-summing dimensions appear after left non-summing dimensions. 
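The Tensor-based overloads in this file address elements through ArrayUtilities.GetIndices, which turns the flat loop counter into per-dimension coordinates using the tensor's strides. A minimal row-major sketch of that stride and coordinate arithmetic (standalone helpers, not the library implementation):

    // Sketch: row-major strides and the linear-index -> coordinates mapping
    // that ArrayUtilities.GetStrides / GetIndices provide for these loops.
    static int[] ComputeStrides(int[] dimensions)
    {
        var strides = new int[dimensions.Length];
        int stride = 1;
        for (int i = dimensions.Length - 1; i >= 0; i--)
        {
            strides[i] = stride;      // elements to skip per step along axis i
            stride *= dimensions[i];
        }
        return strides;
    }

    static void ToCoordinates(int[] strides, int linearIndex, int[] indices)
    {
        for (int i = 0; i < strides.Length; i++)
        {
            indices[i] = linearIndex / strides[i];   // coordinate along axis i
            linearIndex %= strides[i];               // remainder for later axes
        }
    }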
+ int rightNonSummingStridesOffset = (left.Rank - leftAxes.Length); + + // translates from summing index to right summing dimensions' index portion + int[] rightSummingStrides = new int[rightAxes.Length]; + ArrayUtilities.SplitStrides(right.strides, rightAxes, rightNonSummingStrides, rightNonSummingStridesOffset, rightSummingStrides, 0); + + for (int resultIndex = 0; resultIndex < result.Length; resultIndex++) + { + long sum = (long)0; + + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, resultIndex, resultIndices); + + int leftIndexNonSumming = ArrayUtilities.TransformIndexByStrides(resultIndex, resultStrides, result.IsReversedStride, leftNonSummingStrides); + int rightIndexNonSumming = ArrayUtilities.TransformIndexByStrides(resultIndex, resultStrides, result.IsReversedStride, rightNonSummingStrides); + + for (int summingIndex = 0; summingIndex < summingLength; summingIndex++) + { + int leftIndexSumming = ArrayUtilities.TransformIndexByStrides(summingIndex, summingStrides, false, leftSummingStrides); + int rightIndexSumming = ArrayUtilities.TransformIndexByStrides(summingIndex, summingStrides, false, rightSummingStrides); + + int leftIndex = leftIndexNonSumming + leftIndexSumming; + int rightIndex = rightIndexNonSumming + rightIndexSumming; + + // todo, make this more efficient + ArrayUtilities.GetIndices(left.strides, left.IsReversedStride, leftIndex, leftIndices); + ArrayUtilities.GetIndices(right.strides, right.IsReversedStride, rightIndex, rightIndices); + + sum += (long)(left[leftIndices] * right[rightIndices]); + } + + result[resultIndices] = sum; + } + } + public void Decrement(Tensor tensor, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices]--; + } + + } + public void Divide(Tensor left, Tensor right, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = (long)(left[indices] / right[indices]); + } + + } + public void Divide(Tensor tensor, long scalar, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = (long)(tensor[indices] / scalar); + } + + } + public void Equals(Tensor left, Tensor right, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = left[indices] == right[indices]; + } + + } + public void GreaterThan(Tensor left, Tensor right, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = left[indices] > right[indices]; + } + + } + public void GreaterThanOrEqual(Tensor left, Tensor right, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = left[indices] >= right[indices]; + } + + } + public void Increment(Tensor tensor, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + 
for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices]++; + } + + } + public void LeftShift(Tensor tensor, int value, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = (long)(tensor[indices] << value); + } + + } + public void LessThan(Tensor left, Tensor right, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = left[indices] < right[indices]; + } + + } + public void LessThanOrEqual(Tensor left, Tensor right, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = left[indices] <= right[indices]; + } + + } + public void Modulo(Tensor left, Tensor right, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = (long)(left[indices] % right[indices]); + } + + } + public void Modulo(Tensor tensor, long scalar, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = (long)(tensor[indices] % scalar); + } + + } + public void Multiply(Tensor left, Tensor right, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = (long)(left[indices] * right[indices]); + } + + } + public void Multiply(Tensor tensor, long scalar, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = (long)(tensor[indices] * scalar); + } + + } + public void NotEquals(Tensor left, Tensor right, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = left[indices] != right[indices]; + } + + } + public void Or(Tensor left, Tensor right, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = (long)(left[indices] | right[indices]); + } + + } + public void Or(Tensor tensor, long scalar, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = (long)(tensor[indices] | scalar); + } + + } + public void RightShift(Tensor tensor, int value, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = 
(long)(tensor[indices] >> value); + } + + } + public void Subtract(Tensor left, Tensor right, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = (long)(left[indices] - right[indices]); + } + + } + public void Subtract(Tensor tensor, long scalar, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = (long)(tensor[indices] - scalar); + } + + } + public void UnaryMinus(Tensor tensor, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = (long)-tensor[indices]; + } + + } + public void UnaryPlus(Tensor tensor, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = (long)+tensor[indices]; + } + + } + public void Xor(Tensor left, Tensor right, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = (long)(left[indices] ^ right[indices]); + } + + } + public void Xor(Tensor tensor, long scalar, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = (long)(tensor[indices] ^ scalar); + } + + } + + public void Add(DenseTensor left, DenseTensor right, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var leftSpan = left.Buffer.Span; + var rightSpan = right.Buffer.Span; + if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = (long)(leftSpan[i] + rightSpan[i]); + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + !left.IsReversedStride ? left.strides : + right.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + left.IsReversedStride ? 
left.strides : + right.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = (long)(leftSpan[op1Index] + rightSpan[op2Index]); + + } + } + } + public void Add(DenseTensor tensor, long scalar, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var tensorSpan = tensor.Buffer.Span; + if (result.IsReversedStride == tensor.IsReversedStride) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = (long)(tensorSpan[i] + scalar); + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + tensor.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + tensor.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = (long)(tensorSpan[op1Index] + scalar); + + } + } + } + public void And(DenseTensor left, DenseTensor right, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var leftSpan = left.Buffer.Span; + var rightSpan = right.Buffer.Span; + if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = (long)(leftSpan[i] & rightSpan[i]); + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + !left.IsReversedStride ? left.strides : + right.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + left.IsReversedStride ? left.strides : + right.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = (long)(leftSpan[op1Index] & rightSpan[op2Index]); + + } + } + } + public void And(DenseTensor tensor, long scalar, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var tensorSpan = tensor.Buffer.Span; + if (result.IsReversedStride == tensor.IsReversedStride) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = (long)(tensorSpan[i] & scalar); + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + tensor.strides; + var columnMajorStrides = result.IsReversedStride ? 
result.strides : + tensor.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = (long)(tensorSpan[op1Index] & scalar); + + } + } + } + public void Contract(DenseTensor left, DenseTensor right, int[] leftAxes, int[] rightAxes, DenseTensor result) + { + var summingDimensions = new int[leftAxes.Length]; + for(int i = 0; i < leftAxes.Length; i++) + { + summingDimensions[i] = left.dimensions[leftAxes[i]]; + } + + var summingStrides = ArrayUtilities.GetStrides(summingDimensions); + int summingLength = (int)ArrayUtilities.GetProduct(summingDimensions); + + var resultStrides = result.strides; + + // translates from result index to left non-summing dimensions' index portion + // since left non-summing dimensions are given precedence in result, the end is zero-padded + int[] leftNonSummingStrides = new int[result.Rank]; + + // translates from summing index to left summing dimensions' index portion + int[] leftSummingStrides = new int[leftAxes.Length]; + ArrayUtilities.SplitStrides(left.strides, leftAxes, leftNonSummingStrides, 0, leftSummingStrides, 0); + + // translates from result index to right non-summing dimensions' index portion + int[] rightNonSummingStrides = new int[result.Rank]; + // right non-summing dimensions appear after left non-summing dimensions. + int rightNonSummingStridesOffset = (left.Rank - leftAxes.Length); + + // translates from summing index to right summing dimensions' index portion + int[] rightSummingStrides = new int[rightAxes.Length]; + ArrayUtilities.SplitStrides(right.strides, rightAxes, rightNonSummingStrides, rightNonSummingStridesOffset, rightSummingStrides, 0); + + var resultSpan = result.Buffer.Span; + var leftSpan = left.Buffer.Span; + var rightSpan = right.Buffer.Span; + + for (int resultIndex = 0; resultIndex < resultSpan.Length; resultIndex++) + { + long sum = (long)0; + + int leftIndexNonSumming = ArrayUtilities.TransformIndexByStrides(resultIndex, resultStrides, result.IsReversedStride, leftNonSummingStrides); + int rightIndexNonSumming = ArrayUtilities.TransformIndexByStrides(resultIndex, resultStrides, result.IsReversedStride, rightNonSummingStrides); + + for (int summingIndex = 0; summingIndex < summingLength; summingIndex++) + { + int leftIndexSumming = ArrayUtilities.TransformIndexByStrides(summingIndex, summingStrides, false, leftSummingStrides); + int rightIndexSumming = ArrayUtilities.TransformIndexByStrides(summingIndex, summingStrides, false, rightSummingStrides); + + int leftIndex = leftIndexNonSumming + leftIndexSumming; + int rightIndex = rightIndexNonSumming + rightIndexSumming; + + sum += (long)(leftSpan[leftIndex] * rightSpan[rightIndex]); + } + + resultSpan[resultIndex] = sum; + } + } + public void Decrement(DenseTensor tensor, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var tensorSpan = tensor.Buffer.Span; + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i]--; + } + } + public void Divide(DenseTensor left, DenseTensor right, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var leftSpan = left.Buffer.Span; + var rightSpan = right.Buffer.Span; + if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = (long)(leftSpan[i] / rightSpan[i]); + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; 
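                // Note on the mixed-layout path that follows (an explanatory sketch, not generated
                // code): rowMajorIndex counts through the buffers stored row-major and colMajorIndex
                // through the "reversed stride" (column-major) ones. Each operand binds a ref local to
                // whichever counter matches its own layout, so a single loop advances rowMajorIndex
                // and recomputes colMajorIndex via TransformIndexByStrides instead of materialising a
                // multi-dimensional index per element. The ref-conditional trick in isolation:
                //
                //     int a = 0, b = 0;
                //     bool useB = true;                          // hypothetical flag
                //     ref int chosen = ref (useB ? ref b : ref a);
                //     chosen = 5;                                // writes b, leaves a untouched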
+ + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + !left.IsReversedStride ? left.strides : + right.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + left.IsReversedStride ? left.strides : + right.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = (long)(leftSpan[op1Index] / rightSpan[op2Index]); + + } + } + } + public void Divide(DenseTensor tensor, long scalar, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var tensorSpan = tensor.Buffer.Span; + if (result.IsReversedStride == tensor.IsReversedStride) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = (long)(tensorSpan[i] / scalar); + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + tensor.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + tensor.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = (long)(tensorSpan[op1Index] / scalar); + + } + } + } + public void Equals(DenseTensor left, DenseTensor right, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var leftSpan = left.Buffer.Span; + var rightSpan = right.Buffer.Span; + if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = leftSpan[i] == rightSpan[i]; + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + !left.IsReversedStride ? left.strides : + right.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + left.IsReversedStride ? 
left.strides : + right.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = leftSpan[op1Index] == rightSpan[op2Index]; + + } + } + } + public void GreaterThan(DenseTensor left, DenseTensor right, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var leftSpan = left.Buffer.Span; + var rightSpan = right.Buffer.Span; + if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = leftSpan[i] > rightSpan[i]; + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + !left.IsReversedStride ? left.strides : + right.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + left.IsReversedStride ? left.strides : + right.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = leftSpan[op1Index] > rightSpan[op2Index]; + + } + } + } + public void GreaterThanOrEqual(DenseTensor left, DenseTensor right, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var leftSpan = left.Buffer.Span; + var rightSpan = right.Buffer.Span; + if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = leftSpan[i] >= rightSpan[i]; + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + !left.IsReversedStride ? left.strides : + right.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + left.IsReversedStride ? left.strides : + right.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = leftSpan[op1Index] >= rightSpan[op2Index]; + + } + } + } + public void Increment(DenseTensor tensor, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var tensorSpan = tensor.Buffer.Span; + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i]++; + } + } + public void LeftShift(DenseTensor tensor, int value, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var tensorSpan = tensor.Buffer.Span; + if (result.IsReversedStride == tensor.IsReversedStride) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = (long)(tensorSpan[i] << value); + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? 
ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + tensor.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + tensor.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = (long)(tensorSpan[op1Index] << value); + + } + } + } + public void LessThan(DenseTensor left, DenseTensor right, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var leftSpan = left.Buffer.Span; + var rightSpan = right.Buffer.Span; + if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = leftSpan[i] < rightSpan[i]; + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + !left.IsReversedStride ? left.strides : + right.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + left.IsReversedStride ? left.strides : + right.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = leftSpan[op1Index] < rightSpan[op2Index]; + + } + } + } + public void LessThanOrEqual(DenseTensor left, DenseTensor right, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var leftSpan = left.Buffer.Span; + var rightSpan = right.Buffer.Span; + if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = leftSpan[i] <= rightSpan[i]; + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + !left.IsReversedStride ? left.strides : + right.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + left.IsReversedStride ? 
left.strides : + right.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = leftSpan[op1Index] <= rightSpan[op2Index]; + + } + } + } + public void Modulo(DenseTensor left, DenseTensor right, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var leftSpan = left.Buffer.Span; + var rightSpan = right.Buffer.Span; + if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = (long)(leftSpan[i] % rightSpan[i]); + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + !left.IsReversedStride ? left.strides : + right.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + left.IsReversedStride ? left.strides : + right.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = (long)(leftSpan[op1Index] % rightSpan[op2Index]); + + } + } + } + public void Modulo(DenseTensor tensor, long scalar, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var tensorSpan = tensor.Buffer.Span; + if (result.IsReversedStride == tensor.IsReversedStride) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = (long)(tensorSpan[i] % scalar); + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + tensor.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + tensor.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = (long)(tensorSpan[op1Index] % scalar); + + } + } + } + public void Multiply(DenseTensor left, DenseTensor right, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var leftSpan = left.Buffer.Span; + var rightSpan = right.Buffer.Span; + if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = (long)(leftSpan[i] * rightSpan[i]); + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + !left.IsReversedStride ? left.strides : + right.strides; + var columnMajorStrides = result.IsReversedStride ? 
result.strides : + left.IsReversedStride ? left.strides : + right.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = (long)(leftSpan[op1Index] * rightSpan[op2Index]); + + } + } + } + public void Multiply(DenseTensor tensor, long scalar, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var tensorSpan = tensor.Buffer.Span; + if (result.IsReversedStride == tensor.IsReversedStride) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = (long)(tensorSpan[i] * scalar); + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + tensor.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + tensor.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = (long)(tensorSpan[op1Index] * scalar); + + } + } + } + public void NotEquals(DenseTensor left, DenseTensor right, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var leftSpan = left.Buffer.Span; + var rightSpan = right.Buffer.Span; + if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = leftSpan[i] != rightSpan[i]; + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + !left.IsReversedStride ? left.strides : + right.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + left.IsReversedStride ? left.strides : + right.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = leftSpan[op1Index] != rightSpan[op2Index]; + + } + } + } + public void Or(DenseTensor left, DenseTensor right, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var leftSpan = left.Buffer.Span; + var rightSpan = right.Buffer.Span; + if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = (long)(leftSpan[i] | rightSpan[i]); + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + !left.IsReversedStride ? 
left.strides : + right.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + left.IsReversedStride ? left.strides : + right.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = (long)(leftSpan[op1Index] | rightSpan[op2Index]); + + } + } + } + public void Or(DenseTensor tensor, long scalar, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var tensorSpan = tensor.Buffer.Span; + if (result.IsReversedStride == tensor.IsReversedStride) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = (long)(tensorSpan[i] | scalar); + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + tensor.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + tensor.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = (long)(tensorSpan[op1Index] | scalar); + + } + } + } + public void RightShift(DenseTensor tensor, int value, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var tensorSpan = tensor.Buffer.Span; + if (result.IsReversedStride == tensor.IsReversedStride) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = (long)(tensorSpan[i] >> value); + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + tensor.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + tensor.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = (long)(tensorSpan[op1Index] >> value); + + } + } + } + public void Subtract(DenseTensor left, DenseTensor right, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var leftSpan = left.Buffer.Span; + var rightSpan = right.Buffer.Span; + if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = (long)(leftSpan[i] - rightSpan[i]); + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + !left.IsReversedStride ? left.strides : + right.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + left.IsReversedStride ? 
left.strides : + right.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = (long)(leftSpan[op1Index] - rightSpan[op2Index]); + + } + } + } + public void Subtract(DenseTensor tensor, long scalar, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var tensorSpan = tensor.Buffer.Span; + if (result.IsReversedStride == tensor.IsReversedStride) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = (long)(tensorSpan[i] - scalar); + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + tensor.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + tensor.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = (long)(tensorSpan[op1Index] - scalar); + + } + } + } + public void UnaryMinus(DenseTensor tensor, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var tensorSpan = tensor.Buffer.Span; + if (result.IsReversedStride == tensor.IsReversedStride) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = (long)-tensorSpan[i]; + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + tensor.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + tensor.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = (long)-tensorSpan[op1Index]; + + } + } + } + public void UnaryPlus(DenseTensor tensor, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var tensorSpan = tensor.Buffer.Span; + if (result.IsReversedStride == tensor.IsReversedStride) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = (long)+tensorSpan[i]; + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + tensor.strides; + var columnMajorStrides = result.IsReversedStride ? 
result.strides : + tensor.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = (long)+tensorSpan[op1Index]; + + } + } + } + public void Xor(DenseTensor left, DenseTensor right, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var leftSpan = left.Buffer.Span; + var rightSpan = right.Buffer.Span; + if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = (long)(leftSpan[i] ^ rightSpan[i]); + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + !left.IsReversedStride ? left.strides : + right.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + left.IsReversedStride ? left.strides : + right.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = (long)(leftSpan[op1Index] ^ rightSpan[op2Index]); + + } + } + } + public void Xor(DenseTensor tensor, long scalar, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var tensorSpan = tensor.Buffer.Span; + if (result.IsReversedStride == tensor.IsReversedStride) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = (long)(tensorSpan[i] ^ scalar); + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + tensor.strides; + var columnMajorStrides = result.IsReversedStride ? 
result.strides :
+                                                   tensor.strides;
+                for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++)
+                {
+                    colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides);
+
+                    resultSpan[resultIndex] = (long)(tensorSpan[op1Index] ^ scalar);
+
+                }
+            }
+        }
+    }
+    internal class SByteArithmetic : ITensorArithmetic<sbyte>
+    {
+        public sbyte One => 1;
+        public sbyte Zero => 0;
+
+        public void Add(Tensor<sbyte> left, Tensor<sbyte> right, Tensor<sbyte> result)
+        {
+
+            Span<int> indices = new Span<int>(new int[result.Rank]);
+            for(int i = 0; i < result.Length; i++)
+            {
+                ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices);
+                result[indices] = (sbyte)(left[indices] + right[indices]);
+            }
+
+        }
+        public void Add(Tensor<sbyte> tensor, sbyte scalar, Tensor<sbyte> result)
+        {
+
+            Span<int> indices = new Span<int>(new int[result.Rank]);
+            for(int i = 0; i < result.Length; i++)
+            {
+                ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices);
+                result[indices] = (sbyte)(tensor[indices] + scalar);
+            }
+
+        }
+        public void And(Tensor<sbyte> left, Tensor<sbyte> right, Tensor<sbyte> result)
+        {
+
+            Span<int> indices = new Span<int>(new int[result.Rank]);
+            for(int i = 0; i < result.Length; i++)
+            {
+                ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices);
+                result[indices] = (sbyte)(left[indices] & right[indices]);
+            }
+
+        }
+        public void And(Tensor<sbyte> tensor, sbyte scalar, Tensor<sbyte> result)
+        {
+
+            Span<int> indices = new Span<int>(new int[result.Rank]);
+            for(int i = 0; i < result.Length; i++)
+            {
+                ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices);
+                result[indices] = (sbyte)(tensor[indices] & scalar);
+            }
+
+        }
+        public void Contract(Tensor<sbyte> left, Tensor<sbyte> right, int[] leftAxes, int[] rightAxes, Tensor<sbyte> result)
+        {
+            var leftIndices = new int[left.Rank];
+            var rightIndices = new int[right.Rank];
+            var resultIndices = new int[result.Rank];
+
+            var summingDimensions = new int[leftAxes.Length];
+            for(int i = 0; i < leftAxes.Length; i++)
+            {
+                summingDimensions[i] = left.dimensions[leftAxes[i]];
+            }
+
+            var summingStrides = ArrayUtilities.GetStrides(summingDimensions);
+            int summingLength = (int)ArrayUtilities.GetProduct(summingDimensions);
+
+            var resultStrides = result.strides;
+
+            // translates from result index to left non-summing dimensions' index portion
+            // since left non-summing dimensions are given precedence in result, the end is zero-padded
+            int[] leftNonSummingStrides = new int[result.Rank];
+
+            // translates from summing index to left summing dimensions' index portion
+            int[] leftSummingStrides = new int[leftAxes.Length];
+            ArrayUtilities.SplitStrides(left.strides, leftAxes, leftNonSummingStrides, 0, leftSummingStrides, 0);
+
+            // translates from result index to right non-summing dimensions' index portion
+            int[] rightNonSummingStrides = new int[result.Rank];
+            // right non-summing dimensions appear after left non-summing dimensions.
+ int rightNonSummingStridesOffset = (left.Rank - leftAxes.Length); + + // translates from summing index to right summing dimensions' index portion + int[] rightSummingStrides = new int[rightAxes.Length]; + ArrayUtilities.SplitStrides(right.strides, rightAxes, rightNonSummingStrides, rightNonSummingStridesOffset, rightSummingStrides, 0); + + for (int resultIndex = 0; resultIndex < result.Length; resultIndex++) + { + sbyte sum = (sbyte)0; + + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, resultIndex, resultIndices); + + int leftIndexNonSumming = ArrayUtilities.TransformIndexByStrides(resultIndex, resultStrides, result.IsReversedStride, leftNonSummingStrides); + int rightIndexNonSumming = ArrayUtilities.TransformIndexByStrides(resultIndex, resultStrides, result.IsReversedStride, rightNonSummingStrides); + + for (int summingIndex = 0; summingIndex < summingLength; summingIndex++) + { + int leftIndexSumming = ArrayUtilities.TransformIndexByStrides(summingIndex, summingStrides, false, leftSummingStrides); + int rightIndexSumming = ArrayUtilities.TransformIndexByStrides(summingIndex, summingStrides, false, rightSummingStrides); + + int leftIndex = leftIndexNonSumming + leftIndexSumming; + int rightIndex = rightIndexNonSumming + rightIndexSumming; + + // todo, make this more efficient + ArrayUtilities.GetIndices(left.strides, left.IsReversedStride, leftIndex, leftIndices); + ArrayUtilities.GetIndices(right.strides, right.IsReversedStride, rightIndex, rightIndices); + + sum += (sbyte)(left[leftIndices] * right[rightIndices]); + } + + result[resultIndices] = sum; + } + } + public void Decrement(Tensor tensor, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices]--; + } + + } + public void Divide(Tensor left, Tensor right, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = (sbyte)(left[indices] / right[indices]); + } + + } + public void Divide(Tensor tensor, sbyte scalar, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = (sbyte)(tensor[indices] / scalar); + } + + } + public void Equals(Tensor left, Tensor right, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = left[indices] == right[indices]; + } + + } + public void GreaterThan(Tensor left, Tensor right, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = left[indices] > right[indices]; + } + + } + public void GreaterThanOrEqual(Tensor left, Tensor right, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = left[indices] >= right[indices]; + } + + } + public void Increment(Tensor tensor, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); 
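            // Explanatory note on this general (non-dense) overload: it goes through the tensor
            // indexer rather than assuming a contiguous backing buffer, so each iteration below
            // converts the flat counter i back into a full multi-dimensional index with GetIndices.
            // The single Span<int> scratch buffer declared above is reused for every element, so the
            // loop performs no per-element allocation. The DenseTensor<T> overloads later in this
            // class operate directly on Buffer.Span instead, which is the fast path.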
+ for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices]++; + } + + } + public void LeftShift(Tensor tensor, int value, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = (sbyte)(tensor[indices] << value); + } + + } + public void LessThan(Tensor left, Tensor right, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = left[indices] < right[indices]; + } + + } + public void LessThanOrEqual(Tensor left, Tensor right, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = left[indices] <= right[indices]; + } + + } + public void Modulo(Tensor left, Tensor right, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = (sbyte)(left[indices] % right[indices]); + } + + } + public void Modulo(Tensor tensor, sbyte scalar, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = (sbyte)(tensor[indices] % scalar); + } + + } + public void Multiply(Tensor left, Tensor right, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = (sbyte)(left[indices] * right[indices]); + } + + } + public void Multiply(Tensor tensor, sbyte scalar, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = (sbyte)(tensor[indices] * scalar); + } + + } + public void NotEquals(Tensor left, Tensor right, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = left[indices] != right[indices]; + } + + } + public void Or(Tensor left, Tensor right, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = (sbyte)(left[indices] | right[indices]); + } + + } + public void Or(Tensor tensor, sbyte scalar, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = (sbyte)(tensor[indices] | scalar); + } + + } + public void RightShift(Tensor tensor, int value, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = 
(sbyte)(tensor[indices] >> value); + } + + } + public void Subtract(Tensor left, Tensor right, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = (sbyte)(left[indices] - right[indices]); + } + + } + public void Subtract(Tensor tensor, sbyte scalar, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = (sbyte)(tensor[indices] - scalar); + } + + } + public void UnaryMinus(Tensor tensor, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = (sbyte)-tensor[indices]; + } + + } + public void UnaryPlus(Tensor tensor, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = (sbyte)+tensor[indices]; + } + + } + public void Xor(Tensor left, Tensor right, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = (sbyte)(left[indices] ^ right[indices]); + } + + } + public void Xor(Tensor tensor, sbyte scalar, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = (sbyte)(tensor[indices] ^ scalar); + } + + } + + public void Add(DenseTensor left, DenseTensor right, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var leftSpan = left.Buffer.Span; + var rightSpan = right.Buffer.Span; + if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = (sbyte)(leftSpan[i] + rightSpan[i]); + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + !left.IsReversedStride ? left.strides : + right.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + left.IsReversedStride ? 
left.strides : + right.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = (sbyte)(leftSpan[op1Index] + rightSpan[op2Index]); + + } + } + } + public void Add(DenseTensor tensor, sbyte scalar, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var tensorSpan = tensor.Buffer.Span; + if (result.IsReversedStride == tensor.IsReversedStride) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = (sbyte)(tensorSpan[i] + scalar); + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + tensor.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + tensor.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = (sbyte)(tensorSpan[op1Index] + scalar); + + } + } + } + public void And(DenseTensor left, DenseTensor right, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var leftSpan = left.Buffer.Span; + var rightSpan = right.Buffer.Span; + if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = (sbyte)(leftSpan[i] & rightSpan[i]); + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + !left.IsReversedStride ? left.strides : + right.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + left.IsReversedStride ? left.strides : + right.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = (sbyte)(leftSpan[op1Index] & rightSpan[op2Index]); + + } + } + } + public void And(DenseTensor tensor, sbyte scalar, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var tensorSpan = tensor.Buffer.Span; + if (result.IsReversedStride == tensor.IsReversedStride) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = (sbyte)(tensorSpan[i] & scalar); + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + tensor.strides; + var columnMajorStrides = result.IsReversedStride ? 
result.strides : + tensor.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = (sbyte)(tensorSpan[op1Index] & scalar); + + } + } + } + public void Contract(DenseTensor left, DenseTensor right, int[] leftAxes, int[] rightAxes, DenseTensor result) + { + var summingDimensions = new int[leftAxes.Length]; + for(int i = 0; i < leftAxes.Length; i++) + { + summingDimensions[i] = left.dimensions[leftAxes[i]]; + } + + var summingStrides = ArrayUtilities.GetStrides(summingDimensions); + int summingLength = (int)ArrayUtilities.GetProduct(summingDimensions); + + var resultStrides = result.strides; + + // translates from result index to left non-summing dimensions' index portion + // since left non-summing dimensions are given precedence in result, the end is zero-padded + int[] leftNonSummingStrides = new int[result.Rank]; + + // translates from summing index to left summing dimensions' index portion + int[] leftSummingStrides = new int[leftAxes.Length]; + ArrayUtilities.SplitStrides(left.strides, leftAxes, leftNonSummingStrides, 0, leftSummingStrides, 0); + + // translates from result index to right non-summing dimensions' index portion + int[] rightNonSummingStrides = new int[result.Rank]; + // right non-summing dimensions appear after left non-summing dimensions. + int rightNonSummingStridesOffset = (left.Rank - leftAxes.Length); + + // translates from summing index to right summing dimensions' index portion + int[] rightSummingStrides = new int[rightAxes.Length]; + ArrayUtilities.SplitStrides(right.strides, rightAxes, rightNonSummingStrides, rightNonSummingStridesOffset, rightSummingStrides, 0); + + var resultSpan = result.Buffer.Span; + var leftSpan = left.Buffer.Span; + var rightSpan = right.Buffer.Span; + + for (int resultIndex = 0; resultIndex < resultSpan.Length; resultIndex++) + { + sbyte sum = (sbyte)0; + + int leftIndexNonSumming = ArrayUtilities.TransformIndexByStrides(resultIndex, resultStrides, result.IsReversedStride, leftNonSummingStrides); + int rightIndexNonSumming = ArrayUtilities.TransformIndexByStrides(resultIndex, resultStrides, result.IsReversedStride, rightNonSummingStrides); + + for (int summingIndex = 0; summingIndex < summingLength; summingIndex++) + { + int leftIndexSumming = ArrayUtilities.TransformIndexByStrides(summingIndex, summingStrides, false, leftSummingStrides); + int rightIndexSumming = ArrayUtilities.TransformIndexByStrides(summingIndex, summingStrides, false, rightSummingStrides); + + int leftIndex = leftIndexNonSumming + leftIndexSumming; + int rightIndex = rightIndexNonSumming + rightIndexSumming; + + sum += (sbyte)(leftSpan[leftIndex] * rightSpan[rightIndex]); + } + + resultSpan[resultIndex] = sum; + } + } + public void Decrement(DenseTensor tensor, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var tensorSpan = tensor.Buffer.Span; + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i]--; + } + } + public void Divide(DenseTensor left, DenseTensor right, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var leftSpan = left.Buffer.Span; + var rightSpan = right.Buffer.Span; + if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = (sbyte)(leftSpan[i] / rightSpan[i]); + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex 
= 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + !left.IsReversedStride ? left.strides : + right.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + left.IsReversedStride ? left.strides : + right.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = (sbyte)(leftSpan[op1Index] / rightSpan[op2Index]); + + } + } + } + public void Divide(DenseTensor tensor, sbyte scalar, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var tensorSpan = tensor.Buffer.Span; + if (result.IsReversedStride == tensor.IsReversedStride) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = (sbyte)(tensorSpan[i] / scalar); + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + tensor.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + tensor.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = (sbyte)(tensorSpan[op1Index] / scalar); + + } + } + } + public void Equals(DenseTensor left, DenseTensor right, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var leftSpan = left.Buffer.Span; + var rightSpan = right.Buffer.Span; + if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = leftSpan[i] == rightSpan[i]; + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + !left.IsReversedStride ? left.strides : + right.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + left.IsReversedStride ? 
left.strides : + right.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = leftSpan[op1Index] == rightSpan[op2Index]; + + } + } + } + public void GreaterThan(DenseTensor left, DenseTensor right, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var leftSpan = left.Buffer.Span; + var rightSpan = right.Buffer.Span; + if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = leftSpan[i] > rightSpan[i]; + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + !left.IsReversedStride ? left.strides : + right.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + left.IsReversedStride ? left.strides : + right.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = leftSpan[op1Index] > rightSpan[op2Index]; + + } + } + } + public void GreaterThanOrEqual(DenseTensor left, DenseTensor right, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var leftSpan = left.Buffer.Span; + var rightSpan = right.Buffer.Span; + if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = leftSpan[i] >= rightSpan[i]; + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + !left.IsReversedStride ? left.strides : + right.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + left.IsReversedStride ? left.strides : + right.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = leftSpan[op1Index] >= rightSpan[op2Index]; + + } + } + } + public void Increment(DenseTensor tensor, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var tensorSpan = tensor.Buffer.Span; + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i]++; + } + } + public void LeftShift(DenseTensor tensor, int value, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var tensorSpan = tensor.Buffer.Span; + if (result.IsReversedStride == tensor.IsReversedStride) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = (sbyte)(tensorSpan[i] << value); + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? 
ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + tensor.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + tensor.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = (sbyte)(tensorSpan[op1Index] << value); + + } + } + } + public void LessThan(DenseTensor left, DenseTensor right, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var leftSpan = left.Buffer.Span; + var rightSpan = right.Buffer.Span; + if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = leftSpan[i] < rightSpan[i]; + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + !left.IsReversedStride ? left.strides : + right.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + left.IsReversedStride ? left.strides : + right.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = leftSpan[op1Index] < rightSpan[op2Index]; + + } + } + } + public void LessThanOrEqual(DenseTensor left, DenseTensor right, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var leftSpan = left.Buffer.Span; + var rightSpan = right.Buffer.Span; + if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = leftSpan[i] <= rightSpan[i]; + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + !left.IsReversedStride ? left.strides : + right.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + left.IsReversedStride ? 
left.strides : + right.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = leftSpan[op1Index] <= rightSpan[op2Index]; + + } + } + } + public void Modulo(DenseTensor left, DenseTensor right, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var leftSpan = left.Buffer.Span; + var rightSpan = right.Buffer.Span; + if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = (sbyte)(leftSpan[i] % rightSpan[i]); + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + !left.IsReversedStride ? left.strides : + right.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + left.IsReversedStride ? left.strides : + right.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = (sbyte)(leftSpan[op1Index] % rightSpan[op2Index]); + + } + } + } + public void Modulo(DenseTensor tensor, sbyte scalar, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var tensorSpan = tensor.Buffer.Span; + if (result.IsReversedStride == tensor.IsReversedStride) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = (sbyte)(tensorSpan[i] % scalar); + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + tensor.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + tensor.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = (sbyte)(tensorSpan[op1Index] % scalar); + + } + } + } + public void Multiply(DenseTensor left, DenseTensor right, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var leftSpan = left.Buffer.Span; + var rightSpan = right.Buffer.Span; + if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = (sbyte)(leftSpan[i] * rightSpan[i]); + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + !left.IsReversedStride ? left.strides : + right.strides; + var columnMajorStrides = result.IsReversedStride ? 
result.strides : + left.IsReversedStride ? left.strides : + right.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = (sbyte)(leftSpan[op1Index] * rightSpan[op2Index]); + + } + } + } + public void Multiply(DenseTensor tensor, sbyte scalar, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var tensorSpan = tensor.Buffer.Span; + if (result.IsReversedStride == tensor.IsReversedStride) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = (sbyte)(tensorSpan[i] * scalar); + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + tensor.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + tensor.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = (sbyte)(tensorSpan[op1Index] * scalar); + + } + } + } + public void NotEquals(DenseTensor left, DenseTensor right, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var leftSpan = left.Buffer.Span; + var rightSpan = right.Buffer.Span; + if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = leftSpan[i] != rightSpan[i]; + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + !left.IsReversedStride ? left.strides : + right.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + left.IsReversedStride ? left.strides : + right.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = leftSpan[op1Index] != rightSpan[op2Index]; + + } + } + } + public void Or(DenseTensor left, DenseTensor right, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var leftSpan = left.Buffer.Span; + var rightSpan = right.Buffer.Span; + if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = (sbyte)(leftSpan[i] | rightSpan[i]); + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + !left.IsReversedStride ? 
left.strides : + right.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + left.IsReversedStride ? left.strides : + right.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = (sbyte)(leftSpan[op1Index] | rightSpan[op2Index]); + + } + } + } + public void Or(DenseTensor tensor, sbyte scalar, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var tensorSpan = tensor.Buffer.Span; + if (result.IsReversedStride == tensor.IsReversedStride) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = (sbyte)(tensorSpan[i] | scalar); + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + tensor.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + tensor.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = (sbyte)(tensorSpan[op1Index] | scalar); + + } + } + } + public void RightShift(DenseTensor tensor, int value, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var tensorSpan = tensor.Buffer.Span; + if (result.IsReversedStride == tensor.IsReversedStride) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = (sbyte)(tensorSpan[i] >> value); + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + tensor.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + tensor.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = (sbyte)(tensorSpan[op1Index] >> value); + + } + } + } + public void Subtract(DenseTensor left, DenseTensor right, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var leftSpan = left.Buffer.Span; + var rightSpan = right.Buffer.Span; + if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = (sbyte)(leftSpan[i] - rightSpan[i]); + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + !left.IsReversedStride ? left.strides : + right.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + left.IsReversedStride ? 
left.strides : + right.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = (sbyte)(leftSpan[op1Index] - rightSpan[op2Index]); + + } + } + } + public void Subtract(DenseTensor tensor, sbyte scalar, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var tensorSpan = tensor.Buffer.Span; + if (result.IsReversedStride == tensor.IsReversedStride) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = (sbyte)(tensorSpan[i] - scalar); + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + tensor.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + tensor.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = (sbyte)(tensorSpan[op1Index] - scalar); + + } + } + } + public void UnaryMinus(DenseTensor tensor, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var tensorSpan = tensor.Buffer.Span; + if (result.IsReversedStride == tensor.IsReversedStride) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = (sbyte)-tensorSpan[i]; + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + tensor.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + tensor.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = (sbyte)-tensorSpan[op1Index]; + + } + } + } + public void UnaryPlus(DenseTensor tensor, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var tensorSpan = tensor.Buffer.Span; + if (result.IsReversedStride == tensor.IsReversedStride) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = (sbyte)+tensorSpan[i]; + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + tensor.strides; + var columnMajorStrides = result.IsReversedStride ? 
result.strides : + tensor.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = (sbyte)+tensorSpan[op1Index]; + + } + } + } + public void Xor(DenseTensor left, DenseTensor right, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var leftSpan = left.Buffer.Span; + var rightSpan = right.Buffer.Span; + if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = (sbyte)(leftSpan[i] ^ rightSpan[i]); + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + !left.IsReversedStride ? left.strides : + right.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + left.IsReversedStride ? left.strides : + right.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = (sbyte)(leftSpan[op1Index] ^ rightSpan[op2Index]); + + } + } + } + public void Xor(DenseTensor tensor, sbyte scalar, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var tensorSpan = tensor.Buffer.Span; + if (result.IsReversedStride == tensor.IsReversedStride) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = (sbyte)(tensorSpan[i] ^ scalar); + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + tensor.strides; + var columnMajorStrides = result.IsReversedStride ? 
result.strides :
+                                                       tensor.strides;
+                for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++)
+                {
+                    colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides);
+
+                    resultSpan[resultIndex] = (sbyte)(tensorSpan[op1Index] ^ scalar);
+
+                }
+            }
+        }
+    }
+    internal class ShortArithmetic : ITensorArithmetic<short>
+    {
+        public short One => 1;
+        public short Zero => 0;
+
+        public void Add(Tensor<short> left, Tensor<short> right, Tensor<short> result)
+        {
+
+            Span<int> indices = new Span<int>(new int[result.Rank]);
+            for(int i = 0; i < result.Length; i++)
+            {
+                ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices);
+                result[indices] = (short)(left[indices] + right[indices]);
+            }
+
+        }
+        public void Add(Tensor<short> tensor, short scalar, Tensor<short> result)
+        {
+
+            Span<int> indices = new Span<int>(new int[result.Rank]);
+            for(int i = 0; i < result.Length; i++)
+            {
+                ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices);
+                result[indices] = (short)(tensor[indices] + scalar);
+            }
+
+        }
+        public void And(Tensor<short> left, Tensor<short> right, Tensor<short> result)
+        {
+
+            Span<int> indices = new Span<int>(new int[result.Rank]);
+            for(int i = 0; i < result.Length; i++)
+            {
+                ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices);
+                result[indices] = (short)(left[indices] & right[indices]);
+            }
+
+        }
+        public void And(Tensor<short> tensor, short scalar, Tensor<short> result)
+        {
+
+            Span<int> indices = new Span<int>(new int[result.Rank]);
+            for(int i = 0; i < result.Length; i++)
+            {
+                ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices);
+                result[indices] = (short)(tensor[indices] & scalar);
+            }
+
+        }
+        public void Contract(Tensor<short> left, Tensor<short> right, int[] leftAxes, int[] rightAxes, Tensor<short> result)
+        {
+            var leftIndices = new int[left.Rank];
+            var rightIndices = new int[right.Rank];
+            var resultIndices = new int[result.Rank];
+
+            var summingDimensions = new int[leftAxes.Length];
+            for(int i = 0; i < leftAxes.Length; i++)
+            {
+                summingDimensions[i] = left.dimensions[leftAxes[i]];
+            }
+
+            var summingStrides = ArrayUtilities.GetStrides(summingDimensions);
+            int summingLength = (int)ArrayUtilities.GetProduct(summingDimensions);
+
+            var resultStrides = result.strides;
+
+            // translates from result index to left non-summing dimensions' index portion
+            // since left non-summing dimensions are given precedence in result, the end is zero-padded
+            int[] leftNonSummingStrides = new int[result.Rank];
+
+            // translates from summing index to left summing dimensions' index portion
+            int[] leftSummingStrides = new int[leftAxes.Length];
+            ArrayUtilities.SplitStrides(left.strides, leftAxes, leftNonSummingStrides, 0, leftSummingStrides, 0);
+
+            // translates from result index to right non-summing dimensions' index portion
+            int[] rightNonSummingStrides = new int[result.Rank];
+            // right non-summing dimensions appear after left non-summing dimensions.
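+            // Illustrative example (assumed values, not part of the original generated code): contracting a
+            // left tensor with dimensions [2,3,4] (row-major strides [12,4,1]) over leftAxes = [1] with a
+            // right tensor with dimensions [3,5] (strides [5,1]) over rightAxes = [0] produces a rank-3 result.
+            // SplitStrides above would yield leftNonSummingStrides = [12, 1, 0] and leftSummingStrides = [4];
+            // the offset computed next places right's lone non-summing stride at position 2, giving
+            // rightNonSummingStrides = [0, 0, 1] and rightSummingStrides = [5].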
+ int rightNonSummingStridesOffset = (left.Rank - leftAxes.Length); + + // translates from summing index to right summing dimensions' index portion + int[] rightSummingStrides = new int[rightAxes.Length]; + ArrayUtilities.SplitStrides(right.strides, rightAxes, rightNonSummingStrides, rightNonSummingStridesOffset, rightSummingStrides, 0); + + for (int resultIndex = 0; resultIndex < result.Length; resultIndex++) + { + short sum = (short)0; + + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, resultIndex, resultIndices); + + int leftIndexNonSumming = ArrayUtilities.TransformIndexByStrides(resultIndex, resultStrides, result.IsReversedStride, leftNonSummingStrides); + int rightIndexNonSumming = ArrayUtilities.TransformIndexByStrides(resultIndex, resultStrides, result.IsReversedStride, rightNonSummingStrides); + + for (int summingIndex = 0; summingIndex < summingLength; summingIndex++) + { + int leftIndexSumming = ArrayUtilities.TransformIndexByStrides(summingIndex, summingStrides, false, leftSummingStrides); + int rightIndexSumming = ArrayUtilities.TransformIndexByStrides(summingIndex, summingStrides, false, rightSummingStrides); + + int leftIndex = leftIndexNonSumming + leftIndexSumming; + int rightIndex = rightIndexNonSumming + rightIndexSumming; + + // todo, make this more efficient + ArrayUtilities.GetIndices(left.strides, left.IsReversedStride, leftIndex, leftIndices); + ArrayUtilities.GetIndices(right.strides, right.IsReversedStride, rightIndex, rightIndices); + + sum += (short)(left[leftIndices] * right[rightIndices]); + } + + result[resultIndices] = sum; + } + } + public void Decrement(Tensor tensor, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices]--; + } + + } + public void Divide(Tensor left, Tensor right, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = (short)(left[indices] / right[indices]); + } + + } + public void Divide(Tensor tensor, short scalar, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = (short)(tensor[indices] / scalar); + } + + } + public void Equals(Tensor left, Tensor right, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = left[indices] == right[indices]; + } + + } + public void GreaterThan(Tensor left, Tensor right, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = left[indices] > right[indices]; + } + + } + public void GreaterThanOrEqual(Tensor left, Tensor right, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = left[indices] >= right[indices]; + } + + } + public void Increment(Tensor tensor, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); 
+ for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices]++; + } + + } + public void LeftShift(Tensor tensor, int value, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = (short)(tensor[indices] << value); + } + + } + public void LessThan(Tensor left, Tensor right, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = left[indices] < right[indices]; + } + + } + public void LessThanOrEqual(Tensor left, Tensor right, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = left[indices] <= right[indices]; + } + + } + public void Modulo(Tensor left, Tensor right, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = (short)(left[indices] % right[indices]); + } + + } + public void Modulo(Tensor tensor, short scalar, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = (short)(tensor[indices] % scalar); + } + + } + public void Multiply(Tensor left, Tensor right, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = (short)(left[indices] * right[indices]); + } + + } + public void Multiply(Tensor tensor, short scalar, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = (short)(tensor[indices] * scalar); + } + + } + public void NotEquals(Tensor left, Tensor right, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = left[indices] != right[indices]; + } + + } + public void Or(Tensor left, Tensor right, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = (short)(left[indices] | right[indices]); + } + + } + public void Or(Tensor tensor, short scalar, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = (short)(tensor[indices] | scalar); + } + + } + public void RightShift(Tensor tensor, int value, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = 
(short)(tensor[indices] >> value); + } + + } + public void Subtract(Tensor left, Tensor right, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = (short)(left[indices] - right[indices]); + } + + } + public void Subtract(Tensor tensor, short scalar, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = (short)(tensor[indices] - scalar); + } + + } + public void UnaryMinus(Tensor tensor, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = (short)-tensor[indices]; + } + + } + public void UnaryPlus(Tensor tensor, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = (short)+tensor[indices]; + } + + } + public void Xor(Tensor left, Tensor right, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = (short)(left[indices] ^ right[indices]); + } + + } + public void Xor(Tensor tensor, short scalar, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = (short)(tensor[indices] ^ scalar); + } + + } + + public void Add(DenseTensor left, DenseTensor right, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var leftSpan = left.Buffer.Span; + var rightSpan = right.Buffer.Span; + if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = (short)(leftSpan[i] + rightSpan[i]); + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + !left.IsReversedStride ? left.strides : + right.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + left.IsReversedStride ? 
left.strides : + right.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = (short)(leftSpan[op1Index] + rightSpan[op2Index]); + + } + } + } + public void Add(DenseTensor tensor, short scalar, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var tensorSpan = tensor.Buffer.Span; + if (result.IsReversedStride == tensor.IsReversedStride) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = (short)(tensorSpan[i] + scalar); + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + tensor.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + tensor.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = (short)(tensorSpan[op1Index] + scalar); + + } + } + } + public void And(DenseTensor left, DenseTensor right, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var leftSpan = left.Buffer.Span; + var rightSpan = right.Buffer.Span; + if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = (short)(leftSpan[i] & rightSpan[i]); + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + !left.IsReversedStride ? left.strides : + right.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + left.IsReversedStride ? left.strides : + right.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = (short)(leftSpan[op1Index] & rightSpan[op2Index]); + + } + } + } + public void And(DenseTensor tensor, short scalar, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var tensorSpan = tensor.Buffer.Span; + if (result.IsReversedStride == tensor.IsReversedStride) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = (short)(tensorSpan[i] & scalar); + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + tensor.strides; + var columnMajorStrides = result.IsReversedStride ? 
result.strides : + tensor.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = (short)(tensorSpan[op1Index] & scalar); + + } + } + } + public void Contract(DenseTensor left, DenseTensor right, int[] leftAxes, int[] rightAxes, DenseTensor result) + { + var summingDimensions = new int[leftAxes.Length]; + for(int i = 0; i < leftAxes.Length; i++) + { + summingDimensions[i] = left.dimensions[leftAxes[i]]; + } + + var summingStrides = ArrayUtilities.GetStrides(summingDimensions); + int summingLength = (int)ArrayUtilities.GetProduct(summingDimensions); + + var resultStrides = result.strides; + + // translates from result index to left non-summing dimensions' index portion + // since left non-summing dimensions are given precedence in result, the end is zero-padded + int[] leftNonSummingStrides = new int[result.Rank]; + + // translates from summing index to left summing dimensions' index portion + int[] leftSummingStrides = new int[leftAxes.Length]; + ArrayUtilities.SplitStrides(left.strides, leftAxes, leftNonSummingStrides, 0, leftSummingStrides, 0); + + // translates from result index to right non-summing dimensions' index portion + int[] rightNonSummingStrides = new int[result.Rank]; + // right non-summing dimensions appear after left non-summing dimensions. + int rightNonSummingStridesOffset = (left.Rank - leftAxes.Length); + + // translates from summing index to right summing dimensions' index portion + int[] rightSummingStrides = new int[rightAxes.Length]; + ArrayUtilities.SplitStrides(right.strides, rightAxes, rightNonSummingStrides, rightNonSummingStridesOffset, rightSummingStrides, 0); + + var resultSpan = result.Buffer.Span; + var leftSpan = left.Buffer.Span; + var rightSpan = right.Buffer.Span; + + for (int resultIndex = 0; resultIndex < resultSpan.Length; resultIndex++) + { + short sum = (short)0; + + int leftIndexNonSumming = ArrayUtilities.TransformIndexByStrides(resultIndex, resultStrides, result.IsReversedStride, leftNonSummingStrides); + int rightIndexNonSumming = ArrayUtilities.TransformIndexByStrides(resultIndex, resultStrides, result.IsReversedStride, rightNonSummingStrides); + + for (int summingIndex = 0; summingIndex < summingLength; summingIndex++) + { + int leftIndexSumming = ArrayUtilities.TransformIndexByStrides(summingIndex, summingStrides, false, leftSummingStrides); + int rightIndexSumming = ArrayUtilities.TransformIndexByStrides(summingIndex, summingStrides, false, rightSummingStrides); + + int leftIndex = leftIndexNonSumming + leftIndexSumming; + int rightIndex = rightIndexNonSumming + rightIndexSumming; + + sum += (short)(leftSpan[leftIndex] * rightSpan[rightIndex]); + } + + resultSpan[resultIndex] = sum; + } + } + public void Decrement(DenseTensor tensor, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var tensorSpan = tensor.Buffer.Span; + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i]--; + } + } + public void Divide(DenseTensor left, DenseTensor right, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var leftSpan = left.Buffer.Span; + var rightSpan = right.Buffer.Span; + if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = (short)(leftSpan[i] / rightSpan[i]); + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex 
= 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + !left.IsReversedStride ? left.strides : + right.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + left.IsReversedStride ? left.strides : + right.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = (short)(leftSpan[op1Index] / rightSpan[op2Index]); + + } + } + } + public void Divide(DenseTensor tensor, short scalar, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var tensorSpan = tensor.Buffer.Span; + if (result.IsReversedStride == tensor.IsReversedStride) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = (short)(tensorSpan[i] / scalar); + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + tensor.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + tensor.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = (short)(tensorSpan[op1Index] / scalar); + + } + } + } + public void Equals(DenseTensor left, DenseTensor right, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var leftSpan = left.Buffer.Span; + var rightSpan = right.Buffer.Span; + if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = leftSpan[i] == rightSpan[i]; + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + !left.IsReversedStride ? left.strides : + right.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + left.IsReversedStride ? 
left.strides : + right.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = leftSpan[op1Index] == rightSpan[op2Index]; + + } + } + } + public void GreaterThan(DenseTensor left, DenseTensor right, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var leftSpan = left.Buffer.Span; + var rightSpan = right.Buffer.Span; + if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = leftSpan[i] > rightSpan[i]; + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + !left.IsReversedStride ? left.strides : + right.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + left.IsReversedStride ? left.strides : + right.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = leftSpan[op1Index] > rightSpan[op2Index]; + + } + } + } + public void GreaterThanOrEqual(DenseTensor left, DenseTensor right, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var leftSpan = left.Buffer.Span; + var rightSpan = right.Buffer.Span; + if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = leftSpan[i] >= rightSpan[i]; + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + !left.IsReversedStride ? left.strides : + right.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + left.IsReversedStride ? left.strides : + right.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = leftSpan[op1Index] >= rightSpan[op2Index]; + + } + } + } + public void Increment(DenseTensor tensor, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var tensorSpan = tensor.Buffer.Span; + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i]++; + } + } + public void LeftShift(DenseTensor tensor, int value, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var tensorSpan = tensor.Buffer.Span; + if (result.IsReversedStride == tensor.IsReversedStride) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = (short)(tensorSpan[i] << value); + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? 
ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + tensor.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + tensor.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = (short)(tensorSpan[op1Index] << value); + + } + } + } + public void LessThan(DenseTensor left, DenseTensor right, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var leftSpan = left.Buffer.Span; + var rightSpan = right.Buffer.Span; + if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = leftSpan[i] < rightSpan[i]; + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + !left.IsReversedStride ? left.strides : + right.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + left.IsReversedStride ? left.strides : + right.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = leftSpan[op1Index] < rightSpan[op2Index]; + + } + } + } + public void LessThanOrEqual(DenseTensor left, DenseTensor right, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var leftSpan = left.Buffer.Span; + var rightSpan = right.Buffer.Span; + if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = leftSpan[i] <= rightSpan[i]; + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + !left.IsReversedStride ? left.strides : + right.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + left.IsReversedStride ? 
left.strides : + right.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = leftSpan[op1Index] <= rightSpan[op2Index]; + + } + } + } + public void Modulo(DenseTensor left, DenseTensor right, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var leftSpan = left.Buffer.Span; + var rightSpan = right.Buffer.Span; + if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = (short)(leftSpan[i] % rightSpan[i]); + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + !left.IsReversedStride ? left.strides : + right.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + left.IsReversedStride ? left.strides : + right.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = (short)(leftSpan[op1Index] % rightSpan[op2Index]); + + } + } + } + public void Modulo(DenseTensor tensor, short scalar, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var tensorSpan = tensor.Buffer.Span; + if (result.IsReversedStride == tensor.IsReversedStride) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = (short)(tensorSpan[i] % scalar); + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + tensor.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + tensor.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = (short)(tensorSpan[op1Index] % scalar); + + } + } + } + public void Multiply(DenseTensor left, DenseTensor right, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var leftSpan = left.Buffer.Span; + var rightSpan = right.Buffer.Span; + if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = (short)(leftSpan[i] * rightSpan[i]); + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + !left.IsReversedStride ? left.strides : + right.strides; + var columnMajorStrides = result.IsReversedStride ? 
result.strides : + left.IsReversedStride ? left.strides : + right.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = (short)(leftSpan[op1Index] * rightSpan[op2Index]); + + } + } + } + public void Multiply(DenseTensor tensor, short scalar, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var tensorSpan = tensor.Buffer.Span; + if (result.IsReversedStride == tensor.IsReversedStride) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = (short)(tensorSpan[i] * scalar); + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + tensor.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + tensor.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = (short)(tensorSpan[op1Index] * scalar); + + } + } + } + public void NotEquals(DenseTensor left, DenseTensor right, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var leftSpan = left.Buffer.Span; + var rightSpan = right.Buffer.Span; + if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = leftSpan[i] != rightSpan[i]; + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + !left.IsReversedStride ? left.strides : + right.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + left.IsReversedStride ? left.strides : + right.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = leftSpan[op1Index] != rightSpan[op2Index]; + + } + } + } + public void Or(DenseTensor left, DenseTensor right, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var leftSpan = left.Buffer.Span; + var rightSpan = right.Buffer.Span; + if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = (short)(leftSpan[i] | rightSpan[i]); + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + !left.IsReversedStride ? 
left.strides : + right.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + left.IsReversedStride ? left.strides : + right.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = (short)(leftSpan[op1Index] | rightSpan[op2Index]); + + } + } + } + public void Or(DenseTensor tensor, short scalar, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var tensorSpan = tensor.Buffer.Span; + if (result.IsReversedStride == tensor.IsReversedStride) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = (short)(tensorSpan[i] | scalar); + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + tensor.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + tensor.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = (short)(tensorSpan[op1Index] | scalar); + + } + } + } + public void RightShift(DenseTensor tensor, int value, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var tensorSpan = tensor.Buffer.Span; + if (result.IsReversedStride == tensor.IsReversedStride) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = (short)(tensorSpan[i] >> value); + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + tensor.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + tensor.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = (short)(tensorSpan[op1Index] >> value); + + } + } + } + public void Subtract(DenseTensor left, DenseTensor right, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var leftSpan = left.Buffer.Span; + var rightSpan = right.Buffer.Span; + if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = (short)(leftSpan[i] - rightSpan[i]); + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + !left.IsReversedStride ? left.strides : + right.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + left.IsReversedStride ? 
left.strides : + right.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = (short)(leftSpan[op1Index] - rightSpan[op2Index]); + + } + } + } + public void Subtract(DenseTensor tensor, short scalar, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var tensorSpan = tensor.Buffer.Span; + if (result.IsReversedStride == tensor.IsReversedStride) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = (short)(tensorSpan[i] - scalar); + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + tensor.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + tensor.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = (short)(tensorSpan[op1Index] - scalar); + + } + } + } + public void UnaryMinus(DenseTensor tensor, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var tensorSpan = tensor.Buffer.Span; + if (result.IsReversedStride == tensor.IsReversedStride) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = (short)-tensorSpan[i]; + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + tensor.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + tensor.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = (short)-tensorSpan[op1Index]; + + } + } + } + public void UnaryPlus(DenseTensor tensor, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var tensorSpan = tensor.Buffer.Span; + if (result.IsReversedStride == tensor.IsReversedStride) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = (short)+tensorSpan[i]; + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + tensor.strides; + var columnMajorStrides = result.IsReversedStride ? 
result.strides : + tensor.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = (short)+tensorSpan[op1Index]; + + } + } + } + public void Xor(DenseTensor left, DenseTensor right, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var leftSpan = left.Buffer.Span; + var rightSpan = right.Buffer.Span; + if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = (short)(leftSpan[i] ^ rightSpan[i]); + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + !left.IsReversedStride ? left.strides : + right.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + left.IsReversedStride ? left.strides : + right.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = (short)(leftSpan[op1Index] ^ rightSpan[op2Index]); + + } + } + } + public void Xor(DenseTensor tensor, short scalar, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var tensorSpan = tensor.Buffer.Span; + if (result.IsReversedStride == tensor.IsReversedStride) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = (short)(tensorSpan[i] ^ scalar); + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + tensor.strides; + var columnMajorStrides = result.IsReversedStride ? 
result.strides : + tensor.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = (short)(tensorSpan[op1Index] ^ scalar); + + } + } + } + } + internal class UIntArithmetic : ITensorArithmetic + { + public uint One => 1; + public uint Zero => 0; + + public void Add(Tensor left, Tensor right, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = (uint)(left[indices] + right[indices]); + } + + } + public void Add(Tensor tensor, uint scalar, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = (uint)(tensor[indices] + scalar); + } + + } + public void And(Tensor left, Tensor right, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = (uint)(left[indices] & right[indices]); + } + + } + public void And(Tensor tensor, uint scalar, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = (uint)(tensor[indices] & scalar); + } + + } + public void Contract(Tensor left, Tensor right, int[] leftAxes, int[] rightAxes, Tensor result) + { + var leftIndices = new int[left.Rank]; + var rightIndices = new int[right.Rank]; + var resultIndices = new int[result.Rank]; + + var summingDimensions = new int[leftAxes.Length]; + for(int i = 0; i < leftAxes.Length; i++) + { + summingDimensions[i] = left.dimensions[leftAxes[i]]; + } + + var summingStrides = ArrayUtilities.GetStrides(summingDimensions); + int summingLength = (int)ArrayUtilities.GetProduct(summingDimensions); + + var resultStrides = result.strides; + + // translates from result index to left non-summing dimensions' index portion + // since left non-summing dimensions are given precedence in result, the end is zero-padded + int[] leftNonSummingStrides = new int[result.Rank]; + + // translates from summing index to left summing dimensions' index portion + int[] leftSummingStrides = new int[leftAxes.Length]; + ArrayUtilities.SplitStrides(left.strides, leftAxes, leftNonSummingStrides, 0, leftSummingStrides, 0); + + // translates from result index to right non-summing dimensions' index portion + int[] rightNonSummingStrides = new int[result.Rank]; + // right non-summing dimensions appear after left non-summing dimensions. 
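+            // SplitStrides distributes each operand's per-dimension strides into two arrays:
+            // entries for the contracted axes (leftAxes/rightAxes) go into the *Summing* array,
+            // indexed by the summing-space coordinate, while the remaining entries go into the
+            // *NonSumming* array, laid out to match the result's dimension order. Because the
+            // result's dimensions are left's non-summed dimensions followed by right's, right's
+            // non-summing entries are written starting at offset (left.Rank - leftAxes.Length).
+            // TransformIndexByStrides then uses these arrays to turn a result linear index into
+            // each operand's base offset, and a summing-space index into the offset of each term.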
+            int rightNonSummingStridesOffset = (left.Rank - leftAxes.Length);
+
+            // translates from summing index to right summing dimensions' index portion
+            int[] rightSummingStrides = new int[rightAxes.Length];
+            ArrayUtilities.SplitStrides(right.strides, rightAxes, rightNonSummingStrides, rightNonSummingStridesOffset, rightSummingStrides, 0);
+
+            for (int resultIndex = 0; resultIndex < result.Length; resultIndex++)
+            {
+                uint sum = (uint)0;
+
+                ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, resultIndex, resultIndices);
+
+                int leftIndexNonSumming = ArrayUtilities.TransformIndexByStrides(resultIndex, resultStrides, result.IsReversedStride, leftNonSummingStrides);
+                int rightIndexNonSumming = ArrayUtilities.TransformIndexByStrides(resultIndex, resultStrides, result.IsReversedStride, rightNonSummingStrides);
+
+                for (int summingIndex = 0; summingIndex < summingLength; summingIndex++)
+                {
+                    int leftIndexSumming = ArrayUtilities.TransformIndexByStrides(summingIndex, summingStrides, false, leftSummingStrides);
+                    int rightIndexSumming = ArrayUtilities.TransformIndexByStrides(summingIndex, summingStrides, false, rightSummingStrides);
+
+                    int leftIndex = leftIndexNonSumming + leftIndexSumming;
+                    int rightIndex = rightIndexNonSumming + rightIndexSumming;
+
+                    // todo, make this more efficient
+                    ArrayUtilities.GetIndices(left.strides, left.IsReversedStride, leftIndex, leftIndices);
+                    ArrayUtilities.GetIndices(right.strides, right.IsReversedStride, rightIndex, rightIndices);
+
+                    sum += (uint)(left[leftIndices] * right[rightIndices]);
+                }
+
+                result[resultIndices] = sum;
+            }
+        }
+        public void Decrement(Tensor<uint> tensor, Tensor<uint> result)
+        {
+
+            Span<int> indices = new Span<int>(new int[result.Rank]);
+            for(int i = 0; i < result.Length; i++)
+            {
+                ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices);
+                result[indices]--;
+            }
+
+        }
+        public void Divide(Tensor<uint> left, Tensor<uint> right, Tensor<uint> result)
+        {
+
+            Span<int> indices = new Span<int>(new int[result.Rank]);
+            for(int i = 0; i < result.Length; i++)
+            {
+                ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices);
+                result[indices] = (uint)(left[indices] / right[indices]);
+            }
+
+        }
+        public void Divide(Tensor<uint> tensor, uint scalar, Tensor<uint> result)
+        {
+
+            Span<int> indices = new Span<int>(new int[result.Rank]);
+            for(int i = 0; i < result.Length; i++)
+            {
+                ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices);
+                result[indices] = (uint)(tensor[indices] / scalar);
+            }
+
+        }
+        public void Equals(Tensor<uint> left, Tensor<uint> right, Tensor<bool> result)
+        {
+
+            Span<int> indices = new Span<int>(new int[result.Rank]);
+            for(int i = 0; i < result.Length; i++)
+            {
+                ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices);
+                result[indices] = left[indices] == right[indices];
+            }
+
+        }
+        public void GreaterThan(Tensor<uint> left, Tensor<uint> right, Tensor<bool> result)
+        {
+
+            Span<int> indices = new Span<int>(new int[result.Rank]);
+            for(int i = 0; i < result.Length; i++)
+            {
+                ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices);
+                result[indices] = left[indices] > right[indices];
+            }
+
+        }
+        public void GreaterThanOrEqual(Tensor<uint> left, Tensor<uint> right, Tensor<bool> result)
+        {
+
+            Span<int> indices = new Span<int>(new int[result.Rank]);
+            for(int i = 0; i < result.Length; i++)
+            {
+                ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices);
+                result[indices] = left[indices] >= right[indices];
+            }
+
+        }
+        public void Increment(Tensor<uint> tensor, Tensor<uint> result)
+        {
+
+            Span<int> indices = new Span<int>(new int[result.Rank]);
+
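+            // GetIndices expands the linear index i into a full coordinate vector for this
+            // tensor's layout (honouring IsReversedStride), so each element can be accessed
+            // through the general Tensor<uint> indexer. The DenseTensor<uint> overloads further
+            // down avoid this per-element expansion by indexing the backing Buffer.Span directly.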
for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices]++; + } + + } + public void LeftShift(Tensor tensor, int value, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = (uint)(tensor[indices] << value); + } + + } + public void LessThan(Tensor left, Tensor right, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = left[indices] < right[indices]; + } + + } + public void LessThanOrEqual(Tensor left, Tensor right, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = left[indices] <= right[indices]; + } + + } + public void Modulo(Tensor left, Tensor right, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = (uint)(left[indices] % right[indices]); + } + + } + public void Modulo(Tensor tensor, uint scalar, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = (uint)(tensor[indices] % scalar); + } + + } + public void Multiply(Tensor left, Tensor right, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = (uint)(left[indices] * right[indices]); + } + + } + public void Multiply(Tensor tensor, uint scalar, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = (uint)(tensor[indices] * scalar); + } + + } + public void NotEquals(Tensor left, Tensor right, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = left[indices] != right[indices]; + } + + } + public void Or(Tensor left, Tensor right, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = (uint)(left[indices] | right[indices]); + } + + } + public void Or(Tensor tensor, uint scalar, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = (uint)(tensor[indices] | scalar); + } + + } + public void RightShift(Tensor tensor, int value, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = 
(uint)(tensor[indices] >> value); + } + + } + public void Subtract(Tensor left, Tensor right, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = (uint)(left[indices] - right[indices]); + } + + } + public void Subtract(Tensor tensor, uint scalar, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = (uint)(tensor[indices] - scalar); + } + + } + public void UnaryMinus(Tensor tensor, Tensor result) + { + throw new NotSupportedException(); + } + public void UnaryPlus(Tensor tensor, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = (uint)+tensor[indices]; + } + + } + public void Xor(Tensor left, Tensor right, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = (uint)(left[indices] ^ right[indices]); + } + + } + public void Xor(Tensor tensor, uint scalar, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = (uint)(tensor[indices] ^ scalar); + } + + } + + public void Add(DenseTensor left, DenseTensor right, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var leftSpan = left.Buffer.Span; + var rightSpan = right.Buffer.Span; + if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = (uint)(leftSpan[i] + rightSpan[i]); + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + !left.IsReversedStride ? left.strides : + right.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + left.IsReversedStride ? left.strides : + right.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = (uint)(leftSpan[op1Index] + rightSpan[op2Index]); + + } + } + } + public void Add(DenseTensor tensor, uint scalar, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var tensorSpan = tensor.Buffer.Span; + if (result.IsReversedStride == tensor.IsReversedStride) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = (uint)(tensorSpan[i] + scalar); + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref tensor.IsReversedStride ? 
ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + tensor.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + tensor.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = (uint)(tensorSpan[op1Index] + scalar); + + } + } + } + public void And(DenseTensor left, DenseTensor right, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var leftSpan = left.Buffer.Span; + var rightSpan = right.Buffer.Span; + if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = (uint)(leftSpan[i] & rightSpan[i]); + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + !left.IsReversedStride ? left.strides : + right.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + left.IsReversedStride ? left.strides : + right.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = (uint)(leftSpan[op1Index] & rightSpan[op2Index]); + + } + } + } + public void And(DenseTensor tensor, uint scalar, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var tensorSpan = tensor.Buffer.Span; + if (result.IsReversedStride == tensor.IsReversedStride) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = (uint)(tensorSpan[i] & scalar); + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + tensor.strides; + var columnMajorStrides = result.IsReversedStride ? 
result.strides : + tensor.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = (uint)(tensorSpan[op1Index] & scalar); + + } + } + } + public void Contract(DenseTensor left, DenseTensor right, int[] leftAxes, int[] rightAxes, DenseTensor result) + { + var summingDimensions = new int[leftAxes.Length]; + for(int i = 0; i < leftAxes.Length; i++) + { + summingDimensions[i] = left.dimensions[leftAxes[i]]; + } + + var summingStrides = ArrayUtilities.GetStrides(summingDimensions); + int summingLength = (int)ArrayUtilities.GetProduct(summingDimensions); + + var resultStrides = result.strides; + + // translates from result index to left non-summing dimensions' index portion + // since left non-summing dimensions are given precedence in result, the end is zero-padded + int[] leftNonSummingStrides = new int[result.Rank]; + + // translates from summing index to left summing dimensions' index portion + int[] leftSummingStrides = new int[leftAxes.Length]; + ArrayUtilities.SplitStrides(left.strides, leftAxes, leftNonSummingStrides, 0, leftSummingStrides, 0); + + // translates from result index to right non-summing dimensions' index portion + int[] rightNonSummingStrides = new int[result.Rank]; + // right non-summing dimensions appear after left non-summing dimensions. + int rightNonSummingStridesOffset = (left.Rank - leftAxes.Length); + + // translates from summing index to right summing dimensions' index portion + int[] rightSummingStrides = new int[rightAxes.Length]; + ArrayUtilities.SplitStrides(right.strides, rightAxes, rightNonSummingStrides, rightNonSummingStridesOffset, rightSummingStrides, 0); + + var resultSpan = result.Buffer.Span; + var leftSpan = left.Buffer.Span; + var rightSpan = right.Buffer.Span; + + for (int resultIndex = 0; resultIndex < resultSpan.Length; resultIndex++) + { + uint sum = (uint)0; + + int leftIndexNonSumming = ArrayUtilities.TransformIndexByStrides(resultIndex, resultStrides, result.IsReversedStride, leftNonSummingStrides); + int rightIndexNonSumming = ArrayUtilities.TransformIndexByStrides(resultIndex, resultStrides, result.IsReversedStride, rightNonSummingStrides); + + for (int summingIndex = 0; summingIndex < summingLength; summingIndex++) + { + int leftIndexSumming = ArrayUtilities.TransformIndexByStrides(summingIndex, summingStrides, false, leftSummingStrides); + int rightIndexSumming = ArrayUtilities.TransformIndexByStrides(summingIndex, summingStrides, false, rightSummingStrides); + + int leftIndex = leftIndexNonSumming + leftIndexSumming; + int rightIndex = rightIndexNonSumming + rightIndexSumming; + + sum += (uint)(leftSpan[leftIndex] * rightSpan[rightIndex]); + } + + resultSpan[resultIndex] = sum; + } + } + public void Decrement(DenseTensor tensor, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var tensorSpan = tensor.Buffer.Span; + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i]--; + } + } + public void Divide(DenseTensor left, DenseTensor right, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var leftSpan = left.Buffer.Span; + var rightSpan = right.Buffer.Span; + if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = (uint)(leftSpan[i] / rightSpan[i]); + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; 
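+                // This branch runs only when result, left and right do not all share the same
+                // IsReversedStride. The loop advances a row-major counter and derives the matching
+                // column-major counter via TransformIndexByStrides; each tensor is then addressed
+                // through a 'ref int' alias bound to whichever counter matches its own layout.
+                // Because the element-wise operands are assumed to have identical dimensions, any
+                // non-reversed tensor can supply the row-major strides and any reversed one the
+                // column-major strides. e.g. for dimensions {2,3}: row-major strides are {3,1},
+                // column-major strides are {1,2}; row-major index 4 -> coordinates (1,1) ->
+                // column-major index 3.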
+ + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + !left.IsReversedStride ? left.strides : + right.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + left.IsReversedStride ? left.strides : + right.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = (uint)(leftSpan[op1Index] / rightSpan[op2Index]); + + } + } + } + public void Divide(DenseTensor tensor, uint scalar, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var tensorSpan = tensor.Buffer.Span; + if (result.IsReversedStride == tensor.IsReversedStride) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = (uint)(tensorSpan[i] / scalar); + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + tensor.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + tensor.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = (uint)(tensorSpan[op1Index] / scalar); + + } + } + } + public void Equals(DenseTensor left, DenseTensor right, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var leftSpan = left.Buffer.Span; + var rightSpan = right.Buffer.Span; + if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = leftSpan[i] == rightSpan[i]; + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + !left.IsReversedStride ? left.strides : + right.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + left.IsReversedStride ? 
left.strides : + right.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = leftSpan[op1Index] == rightSpan[op2Index]; + + } + } + } + public void GreaterThan(DenseTensor left, DenseTensor right, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var leftSpan = left.Buffer.Span; + var rightSpan = right.Buffer.Span; + if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = leftSpan[i] > rightSpan[i]; + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + !left.IsReversedStride ? left.strides : + right.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + left.IsReversedStride ? left.strides : + right.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = leftSpan[op1Index] > rightSpan[op2Index]; + + } + } + } + public void GreaterThanOrEqual(DenseTensor left, DenseTensor right, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var leftSpan = left.Buffer.Span; + var rightSpan = right.Buffer.Span; + if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = leftSpan[i] >= rightSpan[i]; + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + !left.IsReversedStride ? left.strides : + right.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + left.IsReversedStride ? left.strides : + right.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = leftSpan[op1Index] >= rightSpan[op2Index]; + + } + } + } + public void Increment(DenseTensor tensor, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var tensorSpan = tensor.Buffer.Span; + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i]++; + } + } + public void LeftShift(DenseTensor tensor, int value, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var tensorSpan = tensor.Buffer.Span; + if (result.IsReversedStride == tensor.IsReversedStride) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = (uint)(tensorSpan[i] << value); + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? 
ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + tensor.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + tensor.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = (uint)(tensorSpan[op1Index] << value); + + } + } + } + public void LessThan(DenseTensor left, DenseTensor right, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var leftSpan = left.Buffer.Span; + var rightSpan = right.Buffer.Span; + if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = leftSpan[i] < rightSpan[i]; + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + !left.IsReversedStride ? left.strides : + right.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + left.IsReversedStride ? left.strides : + right.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = leftSpan[op1Index] < rightSpan[op2Index]; + + } + } + } + public void LessThanOrEqual(DenseTensor left, DenseTensor right, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var leftSpan = left.Buffer.Span; + var rightSpan = right.Buffer.Span; + if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = leftSpan[i] <= rightSpan[i]; + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + !left.IsReversedStride ? left.strides : + right.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + left.IsReversedStride ? 
left.strides : + right.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = leftSpan[op1Index] <= rightSpan[op2Index]; + + } + } + } + public void Modulo(DenseTensor left, DenseTensor right, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var leftSpan = left.Buffer.Span; + var rightSpan = right.Buffer.Span; + if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = (uint)(leftSpan[i] % rightSpan[i]); + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + !left.IsReversedStride ? left.strides : + right.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + left.IsReversedStride ? left.strides : + right.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = (uint)(leftSpan[op1Index] % rightSpan[op2Index]); + + } + } + } + public void Modulo(DenseTensor tensor, uint scalar, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var tensorSpan = tensor.Buffer.Span; + if (result.IsReversedStride == tensor.IsReversedStride) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = (uint)(tensorSpan[i] % scalar); + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + tensor.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + tensor.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = (uint)(tensorSpan[op1Index] % scalar); + + } + } + } + public void Multiply(DenseTensor left, DenseTensor right, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var leftSpan = left.Buffer.Span; + var rightSpan = right.Buffer.Span; + if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = (uint)(leftSpan[i] * rightSpan[i]); + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + !left.IsReversedStride ? left.strides : + right.strides; + var columnMajorStrides = result.IsReversedStride ? 
result.strides : + left.IsReversedStride ? left.strides : + right.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = (uint)(leftSpan[op1Index] * rightSpan[op2Index]); + + } + } + } + public void Multiply(DenseTensor tensor, uint scalar, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var tensorSpan = tensor.Buffer.Span; + if (result.IsReversedStride == tensor.IsReversedStride) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = (uint)(tensorSpan[i] * scalar); + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + tensor.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + tensor.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = (uint)(tensorSpan[op1Index] * scalar); + + } + } + } + public void NotEquals(DenseTensor left, DenseTensor right, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var leftSpan = left.Buffer.Span; + var rightSpan = right.Buffer.Span; + if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = leftSpan[i] != rightSpan[i]; + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + !left.IsReversedStride ? left.strides : + right.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + left.IsReversedStride ? left.strides : + right.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = leftSpan[op1Index] != rightSpan[op2Index]; + + } + } + } + public void Or(DenseTensor left, DenseTensor right, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var leftSpan = left.Buffer.Span; + var rightSpan = right.Buffer.Span; + if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = (uint)(leftSpan[i] | rightSpan[i]); + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + !left.IsReversedStride ? 
left.strides : + right.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + left.IsReversedStride ? left.strides : + right.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = (uint)(leftSpan[op1Index] | rightSpan[op2Index]); + + } + } + } + public void Or(DenseTensor tensor, uint scalar, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var tensorSpan = tensor.Buffer.Span; + if (result.IsReversedStride == tensor.IsReversedStride) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = (uint)(tensorSpan[i] | scalar); + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + tensor.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + tensor.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = (uint)(tensorSpan[op1Index] | scalar); + + } + } + } + public void RightShift(DenseTensor tensor, int value, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var tensorSpan = tensor.Buffer.Span; + if (result.IsReversedStride == tensor.IsReversedStride) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = (uint)(tensorSpan[i] >> value); + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + tensor.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + tensor.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = (uint)(tensorSpan[op1Index] >> value); + + } + } + } + public void Subtract(DenseTensor left, DenseTensor right, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var leftSpan = left.Buffer.Span; + var rightSpan = right.Buffer.Span; + if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = (uint)(leftSpan[i] - rightSpan[i]); + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + !left.IsReversedStride ? left.strides : + right.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + left.IsReversedStride ? 
left.strides : + right.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = (uint)(leftSpan[op1Index] - rightSpan[op2Index]); + + } + } + } + public void Subtract(DenseTensor tensor, uint scalar, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var tensorSpan = tensor.Buffer.Span; + if (result.IsReversedStride == tensor.IsReversedStride) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = (uint)(tensorSpan[i] - scalar); + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + tensor.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + tensor.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = (uint)(tensorSpan[op1Index] - scalar); + + } + } + } + public void UnaryMinus(DenseTensor tensor, DenseTensor result) + { + throw new NotSupportedException(); + } + public void UnaryPlus(DenseTensor tensor, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var tensorSpan = tensor.Buffer.Span; + if (result.IsReversedStride == tensor.IsReversedStride) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = (uint)+tensorSpan[i]; + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + tensor.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + tensor.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = (uint)+tensorSpan[op1Index]; + + } + } + } + public void Xor(DenseTensor left, DenseTensor right, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var leftSpan = left.Buffer.Span; + var rightSpan = right.Buffer.Span; + if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = (uint)(leftSpan[i] ^ rightSpan[i]); + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + !left.IsReversedStride ? left.strides : + right.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + left.IsReversedStride ? 
+                                     left.strides :
+                                     right.strides;
+                for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++)
+                {
+                    colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides);
+
+                    resultSpan[resultIndex] = (uint)(leftSpan[op1Index] ^ rightSpan[op2Index]);
+
+                }
+            }
+        }
+        public void Xor(DenseTensor<uint> tensor, uint scalar, DenseTensor<uint> result)
+        {
+
+            var resultSpan = result.Buffer.Span;
+            var tensorSpan = tensor.Buffer.Span;
+            if (result.IsReversedStride == tensor.IsReversedStride)
+            {
+                for(int i = 0; i < resultSpan.Length; i++)
+                {
+                    resultSpan[i] = (uint)(tensorSpan[i] ^ scalar);
+                }
+            }
+            else
+            {
+                int rowMajorIndex = 0;
+                int colMajorIndex = 0;
+
+                ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex;
+                ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex;
+
+                var rowMajorStrides = !result.IsReversedStride ? result.strides :
+                                      tensor.strides;
+                var columnMajorStrides = result.IsReversedStride ? result.strides :
+                                         tensor.strides;
+                for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++)
+                {
+                    colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides);
+
+                    resultSpan[resultIndex] = (uint)(tensorSpan[op1Index] ^ scalar);
+
+                }
+            }
+        }
+    }
+    internal class ULongArithmetic : ITensorArithmetic<ulong>
+    {
+        public ulong One => 1;
+        public ulong Zero => 0;
+
+        public void Add(Tensor<ulong> left, Tensor<ulong> right, Tensor<ulong> result)
+        {
+
+            Span<int> indices = new Span<int>(new int[result.Rank]);
+            for(int i = 0; i < result.Length; i++)
+            {
+                ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices);
+                result[indices] = (ulong)(left[indices] + right[indices]);
+            }
+
+        }
+        public void Add(Tensor<ulong> tensor, ulong scalar, Tensor<ulong> result)
+        {
+
+            Span<int> indices = new Span<int>(new int[result.Rank]);
+            for(int i = 0; i < result.Length; i++)
+            {
+                ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices);
+                result[indices] = (ulong)(tensor[indices] + scalar);
+            }
+
+        }
+        public void And(Tensor<ulong> left, Tensor<ulong> right, Tensor<ulong> result)
+        {
+
+            Span<int> indices = new Span<int>(new int[result.Rank]);
+            for(int i = 0; i < result.Length; i++)
+            {
+                ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices);
+                result[indices] = (ulong)(left[indices] & right[indices]);
+            }
+
+        }
+        public void And(Tensor<ulong> tensor, ulong scalar, Tensor<ulong> result)
+        {
+
+            Span<int> indices = new Span<int>(new int[result.Rank]);
+            for(int i = 0; i < result.Length; i++)
+            {
+                ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices);
+                result[indices] = (ulong)(tensor[indices] & scalar);
+            }
+
+        }
+        public void Contract(Tensor<ulong> left, Tensor<ulong> right, int[] leftAxes, int[] rightAxes, Tensor<ulong> result)
+        {
+            var leftIndices = new int[left.Rank];
+            var rightIndices = new int[right.Rank];
+            var resultIndices = new int[result.Rank];
+
+            var summingDimensions = new int[leftAxes.Length];
+            for(int i = 0; i < leftAxes.Length; i++)
+            {
+                summingDimensions[i] = left.dimensions[leftAxes[i]];
+            }
+
+            var summingStrides = ArrayUtilities.GetStrides(summingDimensions);
+            int summingLength = (int)ArrayUtilities.GetProduct(summingDimensions);
+
+            var resultStrides = result.strides;
+
+            // translates from result index to left non-summing dimensions' index portion
+            // since left non-summing dimensions are given precedence in result, the end is zero-padded
+            int[] leftNonSummingStrides = new int[result.Rank];
+
+            // translates from summing index to left summing dimensions' index portion
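The loops in these generated arithmetic classes all lean on one piece of stride arithmetic: when operands disagree on layout, a flat row-major index is decomposed into per-dimension coordinates and re-encoded against the other operand's reversed (column-major style) strides. The snippet below is an editorial sketch of that remapping under the usual row-major stride convention; ComputeStrides and Remap are hypothetical stand-ins for the library's ArrayUtilities.GetStrides/TransformIndexByStrides, not the actual API, and it is not part of the patch.

    // Illustrative only; mirrors the index remapping used by the generated code above.
    using System;

    static class StrideSketch
    {
        // Row-major strides: stride[i] is the product of the dimensions to its right.
        static int[] ComputeStrides(int[] dims)
        {
            var strides = new int[dims.Length];
            int acc = 1;
            for (int i = dims.Length - 1; i >= 0; i--)
            {
                strides[i] = acc;
                acc *= dims[i];
            }
            return strides;
        }

        // Re-express a flat index taken against row-major 'fromStrides' (descending order)
        // as a flat index against 'toStrides' for the same shape.
        static int Remap(int index, int[] fromStrides, int[] toStrides)
        {
            int remapped = 0;
            int remainder = index;
            for (int i = 0; i < fromStrides.Length; i++)
            {
                int coordinate = remainder / fromStrides[i];
                remainder -= coordinate * fromStrides[i];
                remapped += coordinate * toStrides[i];
            }
            return remapped;
        }

        static void Main()
        {
            int[] dims = { 2, 3 };
            int[] rowMajor = ComputeStrides(dims);           // { 3, 1 }
            int[] colMajor = { 1, 2 };                       // reversed-stride layout of the same shape
            // Flat row-major index 4 is element (1, 1); in the reversed layout it lives at 1*1 + 1*2 = 3.
            Console.WriteLine(Remap(4, rowMajor, colMajor)); // prints 3
        }
    }

Remapping through explicit coordinates like this is exactly what the slow path pays for; it is why the fast path, taken when result and operands agree on IsReversedStride, can walk all the spans with a single index.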
+ int[] leftSummingStrides = new int[leftAxes.Length]; + ArrayUtilities.SplitStrides(left.strides, leftAxes, leftNonSummingStrides, 0, leftSummingStrides, 0); + + // translates from result index to right non-summing dimensions' index portion + int[] rightNonSummingStrides = new int[result.Rank]; + // right non-summing dimensions appear after left non-summing dimensions. + int rightNonSummingStridesOffset = (left.Rank - leftAxes.Length); + + // translates from summing index to right summing dimensions' index portion + int[] rightSummingStrides = new int[rightAxes.Length]; + ArrayUtilities.SplitStrides(right.strides, rightAxes, rightNonSummingStrides, rightNonSummingStridesOffset, rightSummingStrides, 0); + + for (int resultIndex = 0; resultIndex < result.Length; resultIndex++) + { + ulong sum = (ulong)0; + + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, resultIndex, resultIndices); + + int leftIndexNonSumming = ArrayUtilities.TransformIndexByStrides(resultIndex, resultStrides, result.IsReversedStride, leftNonSummingStrides); + int rightIndexNonSumming = ArrayUtilities.TransformIndexByStrides(resultIndex, resultStrides, result.IsReversedStride, rightNonSummingStrides); + + for (int summingIndex = 0; summingIndex < summingLength; summingIndex++) + { + int leftIndexSumming = ArrayUtilities.TransformIndexByStrides(summingIndex, summingStrides, false, leftSummingStrides); + int rightIndexSumming = ArrayUtilities.TransformIndexByStrides(summingIndex, summingStrides, false, rightSummingStrides); + + int leftIndex = leftIndexNonSumming + leftIndexSumming; + int rightIndex = rightIndexNonSumming + rightIndexSumming; + + // todo, make this more efficient + ArrayUtilities.GetIndices(left.strides, left.IsReversedStride, leftIndex, leftIndices); + ArrayUtilities.GetIndices(right.strides, right.IsReversedStride, rightIndex, rightIndices); + + sum += (ulong)(left[leftIndices] * right[rightIndices]); + } + + result[resultIndices] = sum; + } + } + public void Decrement(Tensor tensor, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices]--; + } + + } + public void Divide(Tensor left, Tensor right, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = (ulong)(left[indices] / right[indices]); + } + + } + public void Divide(Tensor tensor, ulong scalar, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = (ulong)(tensor[indices] / scalar); + } + + } + public void Equals(Tensor left, Tensor right, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = left[indices] == right[indices]; + } + + } + public void GreaterThan(Tensor left, Tensor right, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = left[indices] > right[indices]; + } + + } + public void GreaterThanOrEqual(Tensor left, Tensor 
right, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = left[indices] >= right[indices]; + } + + } + public void Increment(Tensor tensor, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices]++; + } + + } + public void LeftShift(Tensor tensor, int value, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = (ulong)(tensor[indices] << value); + } + + } + public void LessThan(Tensor left, Tensor right, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = left[indices] < right[indices]; + } + + } + public void LessThanOrEqual(Tensor left, Tensor right, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = left[indices] <= right[indices]; + } + + } + public void Modulo(Tensor left, Tensor right, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = (ulong)(left[indices] % right[indices]); + } + + } + public void Modulo(Tensor tensor, ulong scalar, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = (ulong)(tensor[indices] % scalar); + } + + } + public void Multiply(Tensor left, Tensor right, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = (ulong)(left[indices] * right[indices]); + } + + } + public void Multiply(Tensor tensor, ulong scalar, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = (ulong)(tensor[indices] * scalar); + } + + } + public void NotEquals(Tensor left, Tensor right, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = left[indices] != right[indices]; + } + + } + public void Or(Tensor left, Tensor right, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = (ulong)(left[indices] | right[indices]); + } + + } + public void Or(Tensor tensor, ulong scalar, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + 
ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = (ulong)(tensor[indices] | scalar); + } + + } + public void RightShift(Tensor tensor, int value, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = (ulong)(tensor[indices] >> value); + } + + } + public void Subtract(Tensor left, Tensor right, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = (ulong)(left[indices] - right[indices]); + } + + } + public void Subtract(Tensor tensor, ulong scalar, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = (ulong)(tensor[indices] - scalar); + } + + } + public void UnaryMinus(Tensor tensor, Tensor result) + { + throw new NotSupportedException(); + } + public void UnaryPlus(Tensor tensor, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = (ulong)+tensor[indices]; + } + + } + public void Xor(Tensor left, Tensor right, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = (ulong)(left[indices] ^ right[indices]); + } + + } + public void Xor(Tensor tensor, ulong scalar, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = (ulong)(tensor[indices] ^ scalar); + } + + } + + public void Add(DenseTensor left, DenseTensor right, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var leftSpan = left.Buffer.Span; + var rightSpan = right.Buffer.Span; + if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = (ulong)(leftSpan[i] + rightSpan[i]); + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + !left.IsReversedStride ? left.strides : + right.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + left.IsReversedStride ? 
left.strides : + right.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = (ulong)(leftSpan[op1Index] + rightSpan[op2Index]); + + } + } + } + public void Add(DenseTensor tensor, ulong scalar, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var tensorSpan = tensor.Buffer.Span; + if (result.IsReversedStride == tensor.IsReversedStride) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = (ulong)(tensorSpan[i] + scalar); + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + tensor.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + tensor.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = (ulong)(tensorSpan[op1Index] + scalar); + + } + } + } + public void And(DenseTensor left, DenseTensor right, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var leftSpan = left.Buffer.Span; + var rightSpan = right.Buffer.Span; + if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = (ulong)(leftSpan[i] & rightSpan[i]); + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + !left.IsReversedStride ? left.strides : + right.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + left.IsReversedStride ? left.strides : + right.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = (ulong)(leftSpan[op1Index] & rightSpan[op2Index]); + + } + } + } + public void And(DenseTensor tensor, ulong scalar, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var tensorSpan = tensor.Buffer.Span; + if (result.IsReversedStride == tensor.IsReversedStride) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = (ulong)(tensorSpan[i] & scalar); + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + tensor.strides; + var columnMajorStrides = result.IsReversedStride ? 
result.strides : + tensor.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = (ulong)(tensorSpan[op1Index] & scalar); + + } + } + } + public void Contract(DenseTensor left, DenseTensor right, int[] leftAxes, int[] rightAxes, DenseTensor result) + { + var summingDimensions = new int[leftAxes.Length]; + for(int i = 0; i < leftAxes.Length; i++) + { + summingDimensions[i] = left.dimensions[leftAxes[i]]; + } + + var summingStrides = ArrayUtilities.GetStrides(summingDimensions); + int summingLength = (int)ArrayUtilities.GetProduct(summingDimensions); + + var resultStrides = result.strides; + + // translates from result index to left non-summing dimensions' index portion + // since left non-summing dimensions are given precedence in result, the end is zero-padded + int[] leftNonSummingStrides = new int[result.Rank]; + + // translates from summing index to left summing dimensions' index portion + int[] leftSummingStrides = new int[leftAxes.Length]; + ArrayUtilities.SplitStrides(left.strides, leftAxes, leftNonSummingStrides, 0, leftSummingStrides, 0); + + // translates from result index to right non-summing dimensions' index portion + int[] rightNonSummingStrides = new int[result.Rank]; + // right non-summing dimensions appear after left non-summing dimensions. + int rightNonSummingStridesOffset = (left.Rank - leftAxes.Length); + + // translates from summing index to right summing dimensions' index portion + int[] rightSummingStrides = new int[rightAxes.Length]; + ArrayUtilities.SplitStrides(right.strides, rightAxes, rightNonSummingStrides, rightNonSummingStridesOffset, rightSummingStrides, 0); + + var resultSpan = result.Buffer.Span; + var leftSpan = left.Buffer.Span; + var rightSpan = right.Buffer.Span; + + for (int resultIndex = 0; resultIndex < resultSpan.Length; resultIndex++) + { + ulong sum = (ulong)0; + + int leftIndexNonSumming = ArrayUtilities.TransformIndexByStrides(resultIndex, resultStrides, result.IsReversedStride, leftNonSummingStrides); + int rightIndexNonSumming = ArrayUtilities.TransformIndexByStrides(resultIndex, resultStrides, result.IsReversedStride, rightNonSummingStrides); + + for (int summingIndex = 0; summingIndex < summingLength; summingIndex++) + { + int leftIndexSumming = ArrayUtilities.TransformIndexByStrides(summingIndex, summingStrides, false, leftSummingStrides); + int rightIndexSumming = ArrayUtilities.TransformIndexByStrides(summingIndex, summingStrides, false, rightSummingStrides); + + int leftIndex = leftIndexNonSumming + leftIndexSumming; + int rightIndex = rightIndexNonSumming + rightIndexSumming; + + sum += (ulong)(leftSpan[leftIndex] * rightSpan[rightIndex]); + } + + resultSpan[resultIndex] = sum; + } + } + public void Decrement(DenseTensor tensor, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var tensorSpan = tensor.Buffer.Span; + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i]--; + } + } + public void Divide(DenseTensor left, DenseTensor right, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var leftSpan = left.Buffer.Span; + var rightSpan = right.Buffer.Span; + if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = (ulong)(leftSpan[i] / rightSpan[i]); + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex 
= 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + !left.IsReversedStride ? left.strides : + right.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + left.IsReversedStride ? left.strides : + right.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = (ulong)(leftSpan[op1Index] / rightSpan[op2Index]); + + } + } + } + public void Divide(DenseTensor tensor, ulong scalar, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var tensorSpan = tensor.Buffer.Span; + if (result.IsReversedStride == tensor.IsReversedStride) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = (ulong)(tensorSpan[i] / scalar); + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + tensor.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + tensor.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = (ulong)(tensorSpan[op1Index] / scalar); + + } + } + } + public void Equals(DenseTensor left, DenseTensor right, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var leftSpan = left.Buffer.Span; + var rightSpan = right.Buffer.Span; + if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = leftSpan[i] == rightSpan[i]; + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + !left.IsReversedStride ? left.strides : + right.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + left.IsReversedStride ? 
left.strides : + right.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = leftSpan[op1Index] == rightSpan[op2Index]; + + } + } + } + public void GreaterThan(DenseTensor left, DenseTensor right, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var leftSpan = left.Buffer.Span; + var rightSpan = right.Buffer.Span; + if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = leftSpan[i] > rightSpan[i]; + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + !left.IsReversedStride ? left.strides : + right.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + left.IsReversedStride ? left.strides : + right.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = leftSpan[op1Index] > rightSpan[op2Index]; + + } + } + } + public void GreaterThanOrEqual(DenseTensor left, DenseTensor right, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var leftSpan = left.Buffer.Span; + var rightSpan = right.Buffer.Span; + if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = leftSpan[i] >= rightSpan[i]; + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + !left.IsReversedStride ? left.strides : + right.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + left.IsReversedStride ? left.strides : + right.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = leftSpan[op1Index] >= rightSpan[op2Index]; + + } + } + } + public void Increment(DenseTensor tensor, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var tensorSpan = tensor.Buffer.Span; + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i]++; + } + } + public void LeftShift(DenseTensor tensor, int value, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var tensorSpan = tensor.Buffer.Span; + if (result.IsReversedStride == tensor.IsReversedStride) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = (ulong)(tensorSpan[i] << value); + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? 
ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + tensor.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + tensor.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = (ulong)(tensorSpan[op1Index] << value); + + } + } + } + public void LessThan(DenseTensor left, DenseTensor right, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var leftSpan = left.Buffer.Span; + var rightSpan = right.Buffer.Span; + if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = leftSpan[i] < rightSpan[i]; + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + !left.IsReversedStride ? left.strides : + right.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + left.IsReversedStride ? left.strides : + right.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = leftSpan[op1Index] < rightSpan[op2Index]; + + } + } + } + public void LessThanOrEqual(DenseTensor left, DenseTensor right, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var leftSpan = left.Buffer.Span; + var rightSpan = right.Buffer.Span; + if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = leftSpan[i] <= rightSpan[i]; + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + !left.IsReversedStride ? left.strides : + right.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + left.IsReversedStride ? 
left.strides : + right.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = leftSpan[op1Index] <= rightSpan[op2Index]; + + } + } + } + public void Modulo(DenseTensor left, DenseTensor right, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var leftSpan = left.Buffer.Span; + var rightSpan = right.Buffer.Span; + if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = (ulong)(leftSpan[i] % rightSpan[i]); + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + !left.IsReversedStride ? left.strides : + right.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + left.IsReversedStride ? left.strides : + right.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = (ulong)(leftSpan[op1Index] % rightSpan[op2Index]); + + } + } + } + public void Modulo(DenseTensor tensor, ulong scalar, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var tensorSpan = tensor.Buffer.Span; + if (result.IsReversedStride == tensor.IsReversedStride) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = (ulong)(tensorSpan[i] % scalar); + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + tensor.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + tensor.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = (ulong)(tensorSpan[op1Index] % scalar); + + } + } + } + public void Multiply(DenseTensor left, DenseTensor right, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var leftSpan = left.Buffer.Span; + var rightSpan = right.Buffer.Span; + if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = (ulong)(leftSpan[i] * rightSpan[i]); + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + !left.IsReversedStride ? left.strides : + right.strides; + var columnMajorStrides = result.IsReversedStride ? 
result.strides : + left.IsReversedStride ? left.strides : + right.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = (ulong)(leftSpan[op1Index] * rightSpan[op2Index]); + + } + } + } + public void Multiply(DenseTensor tensor, ulong scalar, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var tensorSpan = tensor.Buffer.Span; + if (result.IsReversedStride == tensor.IsReversedStride) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = (ulong)(tensorSpan[i] * scalar); + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + tensor.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + tensor.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = (ulong)(tensorSpan[op1Index] * scalar); + + } + } + } + public void NotEquals(DenseTensor left, DenseTensor right, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var leftSpan = left.Buffer.Span; + var rightSpan = right.Buffer.Span; + if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = leftSpan[i] != rightSpan[i]; + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + !left.IsReversedStride ? left.strides : + right.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + left.IsReversedStride ? left.strides : + right.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = leftSpan[op1Index] != rightSpan[op2Index]; + + } + } + } + public void Or(DenseTensor left, DenseTensor right, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var leftSpan = left.Buffer.Span; + var rightSpan = right.Buffer.Span; + if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = (ulong)(leftSpan[i] | rightSpan[i]); + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + !left.IsReversedStride ? 
left.strides : + right.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + left.IsReversedStride ? left.strides : + right.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = (ulong)(leftSpan[op1Index] | rightSpan[op2Index]); + + } + } + } + public void Or(DenseTensor tensor, ulong scalar, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var tensorSpan = tensor.Buffer.Span; + if (result.IsReversedStride == tensor.IsReversedStride) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = (ulong)(tensorSpan[i] | scalar); + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + tensor.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + tensor.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = (ulong)(tensorSpan[op1Index] | scalar); + + } + } + } + public void RightShift(DenseTensor tensor, int value, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var tensorSpan = tensor.Buffer.Span; + if (result.IsReversedStride == tensor.IsReversedStride) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = (ulong)(tensorSpan[i] >> value); + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + tensor.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + tensor.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = (ulong)(tensorSpan[op1Index] >> value); + + } + } + } + public void Subtract(DenseTensor left, DenseTensor right, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var leftSpan = left.Buffer.Span; + var rightSpan = right.Buffer.Span; + if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = (ulong)(leftSpan[i] - rightSpan[i]); + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + !left.IsReversedStride ? left.strides : + right.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + left.IsReversedStride ? 
left.strides : + right.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = (ulong)(leftSpan[op1Index] - rightSpan[op2Index]); + + } + } + } + public void Subtract(DenseTensor tensor, ulong scalar, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var tensorSpan = tensor.Buffer.Span; + if (result.IsReversedStride == tensor.IsReversedStride) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = (ulong)(tensorSpan[i] - scalar); + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + tensor.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + tensor.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = (ulong)(tensorSpan[op1Index] - scalar); + + } + } + } + public void UnaryMinus(DenseTensor tensor, DenseTensor result) + { + throw new NotSupportedException(); + } + public void UnaryPlus(DenseTensor tensor, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var tensorSpan = tensor.Buffer.Span; + if (result.IsReversedStride == tensor.IsReversedStride) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = (ulong)+tensorSpan[i]; + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + tensor.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + tensor.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = (ulong)+tensorSpan[op1Index]; + + } + } + } + public void Xor(DenseTensor left, DenseTensor right, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var leftSpan = left.Buffer.Span; + var rightSpan = right.Buffer.Span; + if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = (ulong)(leftSpan[i] ^ rightSpan[i]); + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + !left.IsReversedStride ? left.strides : + right.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + left.IsReversedStride ? 
+                                     left.strides :
+                                     right.strides;
+                for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++)
+                {
+                    colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides);
+
+                    resultSpan[resultIndex] = (ulong)(leftSpan[op1Index] ^ rightSpan[op2Index]);
+
+                }
+            }
+        }
+        public void Xor(DenseTensor<ulong> tensor, ulong scalar, DenseTensor<ulong> result)
+        {
+
+            var resultSpan = result.Buffer.Span;
+            var tensorSpan = tensor.Buffer.Span;
+            if (result.IsReversedStride == tensor.IsReversedStride)
+            {
+                for(int i = 0; i < resultSpan.Length; i++)
+                {
+                    resultSpan[i] = (ulong)(tensorSpan[i] ^ scalar);
+                }
+            }
+            else
+            {
+                int rowMajorIndex = 0;
+                int colMajorIndex = 0;
+
+                ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex;
+                ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex;
+
+                var rowMajorStrides = !result.IsReversedStride ? result.strides :
+                                      tensor.strides;
+                var columnMajorStrides = result.IsReversedStride ? result.strides :
+                                         tensor.strides;
+                for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++)
+                {
+                    colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides);
+
+                    resultSpan[resultIndex] = (ulong)(tensorSpan[op1Index] ^ scalar);
+
+                }
+            }
+        }
+    }
+    internal class UShortArithmetic : ITensorArithmetic<ushort>
+    {
+        public ushort One => 1;
+        public ushort Zero => 0;
+
+        public void Add(Tensor<ushort> left, Tensor<ushort> right, Tensor<ushort> result)
+        {
+
+            Span<int> indices = new Span<int>(new int[result.Rank]);
+            for(int i = 0; i < result.Length; i++)
+            {
+                ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices);
+                result[indices] = (ushort)(left[indices] + right[indices]);
+            }
+
+        }
+        public void Add(Tensor<ushort> tensor, ushort scalar, Tensor<ushort> result)
+        {
+
+            Span<int> indices = new Span<int>(new int[result.Rank]);
+            for(int i = 0; i < result.Length; i++)
+            {
+                ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices);
+                result[indices] = (ushort)(tensor[indices] + scalar);
+            }
+
+        }
+        public void And(Tensor<ushort> left, Tensor<ushort> right, Tensor<ushort> result)
+        {
+
+            Span<int> indices = new Span<int>(new int[result.Rank]);
+            for(int i = 0; i < result.Length; i++)
+            {
+                ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices);
+                result[indices] = (ushort)(left[indices] & right[indices]);
+            }
+
+        }
+        public void And(Tensor<ushort> tensor, ushort scalar, Tensor<ushort> result)
+        {
+
+            Span<int> indices = new Span<int>(new int[result.Rank]);
+            for(int i = 0; i < result.Length; i++)
+            {
+                ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices);
+                result[indices] = (ushort)(tensor[indices] & scalar);
+            }
+
+        }
+        public void Contract(Tensor<ushort> left, Tensor<ushort> right, int[] leftAxes, int[] rightAxes, Tensor<ushort> result)
+        {
+            var leftIndices = new int[left.Rank];
+            var rightIndices = new int[right.Rank];
+            var resultIndices = new int[result.Rank];
+
+            var summingDimensions = new int[leftAxes.Length];
+            for(int i = 0; i < leftAxes.Length; i++)
+            {
+                summingDimensions[i] = left.dimensions[leftAxes[i]];
+            }
+
+            var summingStrides = ArrayUtilities.GetStrides(summingDimensions);
+            int summingLength = (int)ArrayUtilities.GetProduct(summingDimensions);
+
+            var resultStrides = result.strides;
+
+            // translates from result index to left non-summing dimensions' index portion
+            // since left non-summing dimensions are given precedence in result, the end is zero-padded
+            int[] leftNonSummingStrides = new int[result.Rank];
+
+            // translates from summing index to left summing dimensions'
index portion + int[] leftSummingStrides = new int[leftAxes.Length]; + ArrayUtilities.SplitStrides(left.strides, leftAxes, leftNonSummingStrides, 0, leftSummingStrides, 0); + + // translates from result index to right non-summing dimensions' index portion + int[] rightNonSummingStrides = new int[result.Rank]; + // right non-summing dimensions appear after left non-summing dimensions. + int rightNonSummingStridesOffset = (left.Rank - leftAxes.Length); + + // translates from summing index to right summing dimensions' index portion + int[] rightSummingStrides = new int[rightAxes.Length]; + ArrayUtilities.SplitStrides(right.strides, rightAxes, rightNonSummingStrides, rightNonSummingStridesOffset, rightSummingStrides, 0); + + for (int resultIndex = 0; resultIndex < result.Length; resultIndex++) + { + ushort sum = (ushort)0; + + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, resultIndex, resultIndices); + + int leftIndexNonSumming = ArrayUtilities.TransformIndexByStrides(resultIndex, resultStrides, result.IsReversedStride, leftNonSummingStrides); + int rightIndexNonSumming = ArrayUtilities.TransformIndexByStrides(resultIndex, resultStrides, result.IsReversedStride, rightNonSummingStrides); + + for (int summingIndex = 0; summingIndex < summingLength; summingIndex++) + { + int leftIndexSumming = ArrayUtilities.TransformIndexByStrides(summingIndex, summingStrides, false, leftSummingStrides); + int rightIndexSumming = ArrayUtilities.TransformIndexByStrides(summingIndex, summingStrides, false, rightSummingStrides); + + int leftIndex = leftIndexNonSumming + leftIndexSumming; + int rightIndex = rightIndexNonSumming + rightIndexSumming; + + // todo, make this more efficient + ArrayUtilities.GetIndices(left.strides, left.IsReversedStride, leftIndex, leftIndices); + ArrayUtilities.GetIndices(right.strides, right.IsReversedStride, rightIndex, rightIndices); + + sum += (ushort)(left[leftIndices] * right[rightIndices]); + } + + result[resultIndices] = sum; + } + } + public void Decrement(Tensor tensor, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices]--; + } + + } + public void Divide(Tensor left, Tensor right, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = (ushort)(left[indices] / right[indices]); + } + + } + public void Divide(Tensor tensor, ushort scalar, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = (ushort)(tensor[indices] / scalar); + } + + } + public void Equals(Tensor left, Tensor right, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = left[indices] == right[indices]; + } + + } + public void GreaterThan(Tensor left, Tensor right, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = left[indices] > right[indices]; + } + + } + public void 
GreaterThanOrEqual(Tensor left, Tensor right, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = left[indices] >= right[indices]; + } + + } + public void Increment(Tensor tensor, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices]++; + } + + } + public void LeftShift(Tensor tensor, int value, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = (ushort)(tensor[indices] << value); + } + + } + public void LessThan(Tensor left, Tensor right, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = left[indices] < right[indices]; + } + + } + public void LessThanOrEqual(Tensor left, Tensor right, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = left[indices] <= right[indices]; + } + + } + public void Modulo(Tensor left, Tensor right, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = (ushort)(left[indices] % right[indices]); + } + + } + public void Modulo(Tensor tensor, ushort scalar, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = (ushort)(tensor[indices] % scalar); + } + + } + public void Multiply(Tensor left, Tensor right, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = (ushort)(left[indices] * right[indices]); + } + + } + public void Multiply(Tensor tensor, ushort scalar, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = (ushort)(tensor[indices] * scalar); + } + + } + public void NotEquals(Tensor left, Tensor right, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = left[indices] != right[indices]; + } + + } + public void Or(Tensor left, Tensor right, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = (ushort)(left[indices] | right[indices]); + } + + } + public void Or(Tensor tensor, ushort scalar, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; 
i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = (ushort)(tensor[indices] | scalar); + } + + } + public void RightShift(Tensor tensor, int value, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = (ushort)(tensor[indices] >> value); + } + + } + public void Subtract(Tensor left, Tensor right, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = (ushort)(left[indices] - right[indices]); + } + + } + public void Subtract(Tensor tensor, ushort scalar, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = (ushort)(tensor[indices] - scalar); + } + + } + public void UnaryMinus(Tensor tensor, Tensor result) + { + throw new NotSupportedException(); + } + public void UnaryPlus(Tensor tensor, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = (ushort)+tensor[indices]; + } + + } + public void Xor(Tensor left, Tensor right, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = (ushort)(left[indices] ^ right[indices]); + } + + } + public void Xor(Tensor tensor, ushort scalar, Tensor result) + { + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < result.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + result[indices] = (ushort)(tensor[indices] ^ scalar); + } + + } + + public void Add(DenseTensor left, DenseTensor right, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var leftSpan = left.Buffer.Span; + var rightSpan = right.Buffer.Span; + if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = (ushort)(leftSpan[i] + rightSpan[i]); + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + !left.IsReversedStride ? left.strides : + right.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + left.IsReversedStride ? 
left.strides : + right.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = (ushort)(leftSpan[op1Index] + rightSpan[op2Index]); + + } + } + } + public void Add(DenseTensor tensor, ushort scalar, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var tensorSpan = tensor.Buffer.Span; + if (result.IsReversedStride == tensor.IsReversedStride) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = (ushort)(tensorSpan[i] + scalar); + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + tensor.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + tensor.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = (ushort)(tensorSpan[op1Index] + scalar); + + } + } + } + public void And(DenseTensor left, DenseTensor right, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var leftSpan = left.Buffer.Span; + var rightSpan = right.Buffer.Span; + if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = (ushort)(leftSpan[i] & rightSpan[i]); + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + !left.IsReversedStride ? left.strides : + right.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + left.IsReversedStride ? left.strides : + right.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = (ushort)(leftSpan[op1Index] & rightSpan[op2Index]); + + } + } + } + public void And(DenseTensor tensor, ushort scalar, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var tensorSpan = tensor.Buffer.Span; + if (result.IsReversedStride == tensor.IsReversedStride) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = (ushort)(tensorSpan[i] & scalar); + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + tensor.strides; + var columnMajorStrides = result.IsReversedStride ? 
result.strides : + tensor.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = (ushort)(tensorSpan[op1Index] & scalar); + + } + } + } + public void Contract(DenseTensor left, DenseTensor right, int[] leftAxes, int[] rightAxes, DenseTensor result) + { + var summingDimensions = new int[leftAxes.Length]; + for(int i = 0; i < leftAxes.Length; i++) + { + summingDimensions[i] = left.dimensions[leftAxes[i]]; + } + + var summingStrides = ArrayUtilities.GetStrides(summingDimensions); + int summingLength = (int)ArrayUtilities.GetProduct(summingDimensions); + + var resultStrides = result.strides; + + // translates from result index to left non-summing dimensions' index portion + // since left non-summing dimensions are given precedence in result, the end is zero-padded + int[] leftNonSummingStrides = new int[result.Rank]; + + // translates from summing index to left summing dimensions' index portion + int[] leftSummingStrides = new int[leftAxes.Length]; + ArrayUtilities.SplitStrides(left.strides, leftAxes, leftNonSummingStrides, 0, leftSummingStrides, 0); + + // translates from result index to right non-summing dimensions' index portion + int[] rightNonSummingStrides = new int[result.Rank]; + // right non-summing dimensions appear after left non-summing dimensions. + int rightNonSummingStridesOffset = (left.Rank - leftAxes.Length); + + // translates from summing index to right summing dimensions' index portion + int[] rightSummingStrides = new int[rightAxes.Length]; + ArrayUtilities.SplitStrides(right.strides, rightAxes, rightNonSummingStrides, rightNonSummingStridesOffset, rightSummingStrides, 0); + + var resultSpan = result.Buffer.Span; + var leftSpan = left.Buffer.Span; + var rightSpan = right.Buffer.Span; + + for (int resultIndex = 0; resultIndex < resultSpan.Length; resultIndex++) + { + ushort sum = (ushort)0; + + int leftIndexNonSumming = ArrayUtilities.TransformIndexByStrides(resultIndex, resultStrides, result.IsReversedStride, leftNonSummingStrides); + int rightIndexNonSumming = ArrayUtilities.TransformIndexByStrides(resultIndex, resultStrides, result.IsReversedStride, rightNonSummingStrides); + + for (int summingIndex = 0; summingIndex < summingLength; summingIndex++) + { + int leftIndexSumming = ArrayUtilities.TransformIndexByStrides(summingIndex, summingStrides, false, leftSummingStrides); + int rightIndexSumming = ArrayUtilities.TransformIndexByStrides(summingIndex, summingStrides, false, rightSummingStrides); + + int leftIndex = leftIndexNonSumming + leftIndexSumming; + int rightIndex = rightIndexNonSumming + rightIndexSumming; + + sum += (ushort)(leftSpan[leftIndex] * rightSpan[rightIndex]); + } + + resultSpan[resultIndex] = sum; + } + } + public void Decrement(DenseTensor tensor, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var tensorSpan = tensor.Buffer.Span; + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i]--; + } + } + public void Divide(DenseTensor left, DenseTensor right, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var leftSpan = left.Buffer.Span; + var rightSpan = right.Buffer.Span; + if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = (ushort)(leftSpan[i] / rightSpan[i]); + } + } + else + { + int rowMajorIndex = 0; + int 
colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + !left.IsReversedStride ? left.strides : + right.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + left.IsReversedStride ? left.strides : + right.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = (ushort)(leftSpan[op1Index] / rightSpan[op2Index]); + + } + } + } + public void Divide(DenseTensor tensor, ushort scalar, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var tensorSpan = tensor.Buffer.Span; + if (result.IsReversedStride == tensor.IsReversedStride) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = (ushort)(tensorSpan[i] / scalar); + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + tensor.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + tensor.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = (ushort)(tensorSpan[op1Index] / scalar); + + } + } + } + public void Equals(DenseTensor left, DenseTensor right, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var leftSpan = left.Buffer.Span; + var rightSpan = right.Buffer.Span; + if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = leftSpan[i] == rightSpan[i]; + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + !left.IsReversedStride ? left.strides : + right.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + left.IsReversedStride ? 
left.strides : + right.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = leftSpan[op1Index] == rightSpan[op2Index]; + + } + } + } + public void GreaterThan(DenseTensor left, DenseTensor right, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var leftSpan = left.Buffer.Span; + var rightSpan = right.Buffer.Span; + if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = leftSpan[i] > rightSpan[i]; + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + !left.IsReversedStride ? left.strides : + right.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + left.IsReversedStride ? left.strides : + right.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = leftSpan[op1Index] > rightSpan[op2Index]; + + } + } + } + public void GreaterThanOrEqual(DenseTensor left, DenseTensor right, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var leftSpan = left.Buffer.Span; + var rightSpan = right.Buffer.Span; + if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = leftSpan[i] >= rightSpan[i]; + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + !left.IsReversedStride ? left.strides : + right.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + left.IsReversedStride ? left.strides : + right.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = leftSpan[op1Index] >= rightSpan[op2Index]; + + } + } + } + public void Increment(DenseTensor tensor, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var tensorSpan = tensor.Buffer.Span; + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i]++; + } + } + public void LeftShift(DenseTensor tensor, int value, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var tensorSpan = tensor.Buffer.Span; + if (result.IsReversedStride == tensor.IsReversedStride) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = (ushort)(tensorSpan[i] << value); + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? 
ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + tensor.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + tensor.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = (ushort)(tensorSpan[op1Index] << value); + + } + } + } + public void LessThan(DenseTensor left, DenseTensor right, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var leftSpan = left.Buffer.Span; + var rightSpan = right.Buffer.Span; + if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = leftSpan[i] < rightSpan[i]; + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + !left.IsReversedStride ? left.strides : + right.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + left.IsReversedStride ? left.strides : + right.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = leftSpan[op1Index] < rightSpan[op2Index]; + + } + } + } + public void LessThanOrEqual(DenseTensor left, DenseTensor right, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var leftSpan = left.Buffer.Span; + var rightSpan = right.Buffer.Span; + if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = leftSpan[i] <= rightSpan[i]; + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + !left.IsReversedStride ? left.strides : + right.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + left.IsReversedStride ? 
left.strides : + right.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = leftSpan[op1Index] <= rightSpan[op2Index]; + + } + } + } + public void Modulo(DenseTensor left, DenseTensor right, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var leftSpan = left.Buffer.Span; + var rightSpan = right.Buffer.Span; + if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = (ushort)(leftSpan[i] % rightSpan[i]); + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + !left.IsReversedStride ? left.strides : + right.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + left.IsReversedStride ? left.strides : + right.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = (ushort)(leftSpan[op1Index] % rightSpan[op2Index]); + + } + } + } + public void Modulo(DenseTensor tensor, ushort scalar, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var tensorSpan = tensor.Buffer.Span; + if (result.IsReversedStride == tensor.IsReversedStride) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = (ushort)(tensorSpan[i] % scalar); + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + tensor.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + tensor.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = (ushort)(tensorSpan[op1Index] % scalar); + + } + } + } + public void Multiply(DenseTensor left, DenseTensor right, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var leftSpan = left.Buffer.Span; + var rightSpan = right.Buffer.Span; + if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = (ushort)(leftSpan[i] * rightSpan[i]); + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + !left.IsReversedStride ? left.strides : + right.strides; + var columnMajorStrides = result.IsReversedStride ? 
result.strides : + left.IsReversedStride ? left.strides : + right.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = (ushort)(leftSpan[op1Index] * rightSpan[op2Index]); + + } + } + } + public void Multiply(DenseTensor tensor, ushort scalar, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var tensorSpan = tensor.Buffer.Span; + if (result.IsReversedStride == tensor.IsReversedStride) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = (ushort)(tensorSpan[i] * scalar); + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + tensor.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + tensor.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = (ushort)(tensorSpan[op1Index] * scalar); + + } + } + } + public void NotEquals(DenseTensor left, DenseTensor right, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var leftSpan = left.Buffer.Span; + var rightSpan = right.Buffer.Span; + if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = leftSpan[i] != rightSpan[i]; + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + !left.IsReversedStride ? left.strides : + right.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + left.IsReversedStride ? left.strides : + right.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = leftSpan[op1Index] != rightSpan[op2Index]; + + } + } + } + public void Or(DenseTensor left, DenseTensor right, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var leftSpan = left.Buffer.Span; + var rightSpan = right.Buffer.Span; + if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = (ushort)(leftSpan[i] | rightSpan[i]); + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + !left.IsReversedStride ? 
left.strides : + right.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + left.IsReversedStride ? left.strides : + right.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = (ushort)(leftSpan[op1Index] | rightSpan[op2Index]); + + } + } + } + public void Or(DenseTensor tensor, ushort scalar, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var tensorSpan = tensor.Buffer.Span; + if (result.IsReversedStride == tensor.IsReversedStride) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = (ushort)(tensorSpan[i] | scalar); + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + tensor.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + tensor.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = (ushort)(tensorSpan[op1Index] | scalar); + + } + } + } + public void RightShift(DenseTensor tensor, int value, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var tensorSpan = tensor.Buffer.Span; + if (result.IsReversedStride == tensor.IsReversedStride) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = (ushort)(tensorSpan[i] >> value); + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + tensor.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + tensor.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = (ushort)(tensorSpan[op1Index] >> value); + + } + } + } + public void Subtract(DenseTensor left, DenseTensor right, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var leftSpan = left.Buffer.Span; + var rightSpan = right.Buffer.Span; + if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = (ushort)(leftSpan[i] - rightSpan[i]); + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + !left.IsReversedStride ? left.strides : + right.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + left.IsReversedStride ? 
left.strides : + right.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = (ushort)(leftSpan[op1Index] - rightSpan[op2Index]); + + } + } + } + public void Subtract(DenseTensor tensor, ushort scalar, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var tensorSpan = tensor.Buffer.Span; + if (result.IsReversedStride == tensor.IsReversedStride) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = (ushort)(tensorSpan[i] - scalar); + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + tensor.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + tensor.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = (ushort)(tensorSpan[op1Index] - scalar); + + } + } + } + public void UnaryMinus(DenseTensor tensor, DenseTensor result) + { + throw new NotSupportedException(); + } + public void UnaryPlus(DenseTensor tensor, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var tensorSpan = tensor.Buffer.Span; + if (result.IsReversedStride == tensor.IsReversedStride) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = (ushort)+tensorSpan[i]; + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + tensor.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + tensor.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = (ushort)+tensorSpan[op1Index]; + + } + } + } + public void Xor(DenseTensor left, DenseTensor right, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var leftSpan = left.Buffer.Span; + var rightSpan = right.Buffer.Span; + if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = (ushort)(leftSpan[i] ^ rightSpan[i]); + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + !left.IsReversedStride ? left.strides : + right.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + left.IsReversedStride ? 
left.strides : + right.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = (ushort)(leftSpan[op1Index] ^ rightSpan[op2Index]); + + } + } + } + public void Xor(DenseTensor tensor, ushort scalar, DenseTensor result) + { + + var resultSpan = result.Buffer.Span; + var tensorSpan = tensor.Buffer.Span; + if (result.IsReversedStride == tensor.IsReversedStride) + { + for(int i = 0; i < resultSpan.Length; i++) + { + resultSpan[i] = (ushort)(tensorSpan[i] ^ scalar); + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !result.IsReversedStride ? result.strides : + tensor.strides; + var columnMajorStrides = result.IsReversedStride ? result.strides : + tensor.strides; + for(;rowMajorIndex < resultSpan.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + resultSpan[resultIndex] = (ushort)(tensorSpan[op1Index] ^ scalar); + + } + } + } + } +} diff --git a/csharp/test/Microsoft.ML.OnnxRuntime.Tests/Tensors/TensorArithmetic.tt b/csharp/test/Microsoft.ML.OnnxRuntime.Tests/Tensors/TensorArithmetic.tt new file mode 100644 index 0000000000000..dc7741052f702 --- /dev/null +++ b/csharp/test/Microsoft.ML.OnnxRuntime.Tests/Tensors/TensorArithmetic.tt @@ -0,0 +1,249 @@ +<#@ template debug="false" hostspecific="false" language="C#" #> +<#@ assembly name="System.Core" #> +<#@ output extension=".cs" #> +<#@ include file="TensorTemplate.ttinclude" #>// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// This file is copied and adapted from the following git repository - +// https://github.com/dotnet/corefx +// Commit ID: bdd0814360d4c3a58860919f292a306242f27da1 +// Path: /src/System.Numerics.Tensors/tests/TensorArithmetic.cs +// Original license statement below - + +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. 
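
The element-wise DenseTensor overloads above take a single fast loop when the result and all operands share the same stride order, and otherwise keep one row-major counter and map it to the equivalent column-major offset through the two stride sets (ArrayUtilities.TransformIndexByStrides). The standalone sketch below illustrates that mapping for a 2x3 tensor; RowMajorStrides, ColumnMajorStrides and TransformIndex are illustrative stand-ins written for this note, not the ArrayUtilities API itself.

using System;

static class StrideDemo
{
    // Row-major strides for the given dimensions: the last dimension is contiguous.
    static int[] RowMajorStrides(int[] dims)
    {
        var strides = new int[dims.Length];
        int stride = 1;
        for (int i = dims.Length - 1; i >= 0; i--)
        {
            strides[i] = stride;
            stride *= dims[i];
        }
        return strides;
    }

    // Column-major ("reversed stride") strides: the first dimension is contiguous.
    static int[] ColumnMajorStrides(int[] dims)
    {
        var strides = new int[dims.Length];
        int stride = 1;
        for (int i = 0; i < dims.Length; i++)
        {
            strides[i] = stride;
            stride *= dims[i];
        }
        return strides;
    }

    // Decompose a linear index with one stride set and recompose it with another.
    // Assumes fromStrides are non-increasing (row-major), which matches how the
    // generated code calls its transform: row-major counter in, column-major offset out.
    static int TransformIndex(int index, int[] fromStrides, int[] toStrides)
    {
        int remainder = index;
        int transformed = 0;
        for (int i = 0; i < fromStrides.Length; i++)
        {
            int coordinate = remainder / fromStrides[i];
            remainder -= coordinate * fromStrides[i];
            transformed += coordinate * toStrides[i];
        }
        return transformed;
    }

    static void Main()
    {
        var dims = new[] { 2, 3 };                 // a 2x3 tensor
        var rowMajor = RowMajorStrides(dims);      // { 3, 1 }
        var colMajor = ColumnMajorStrides(dims);   // { 1, 2 }

        int rowMajorIndex = 2;                     // element (0, 2) in a row-major buffer
        int colMajorIndex = TransformIndex(rowMajorIndex, rowMajor, colMajor);
        Console.WriteLine(colMajorIndex);          // 4: offset of (0, 2) in a column-major buffer
    }
}

Aliasing resultIndex, op1Index and op2Index onto either counter with ref locals, as the generated code does, means the mixed-layout path pays one such transformation per element instead of materializing an index array for every element.
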
+ +using System; + +namespace Microsoft.ML.OnnxRuntime.Tensors +{ + internal interface ITensorArithmetic + { + T One { get; } + T Zero { get; } +<# foreach (MethodConfiguration method in methodConfiguration) { #> + <#= method.GetResultMethodSignature("Tensor", "T")#>; +<# } #> + } + + internal static class TensorArithmetic + { + public static ITensorArithmetic Instance => TensorArithmetic.GetArithmetic(); + } + + internal static class TensorArithmetic + { + public static ITensorArithmetic GetArithmetic() + { +<# foreach (TypeConfiguration type in typeConfiguration) { #> + <#=GenerateIfStatementHeader(type)#> + { + return (ITensorArithmetic)new <#=type.ClassPrefix#>Arithmetic(); + } +<# } #> + return null; + } + } + +<# foreach (TypeConfiguration type in typeConfiguration) { #> + internal class <#=type.ClassPrefix#>Arithmetic : ITensorArithmetic<<#=type.TypeName#>> + { + public <#=type.TypeName#> One => <#=type.OneLiteral#>; + public <#=type.TypeName#> Zero => <#=type.ZeroLiteral#>; + +<# foreach (MethodConfiguration method in methodConfiguration) { #> + public <#= method.GetResultMethodSignature("Tensor", type.TypeName)#> + { +<# if ((method.IsNumeric && !type.SupportsNumeric) || (method.IsBitwise && !type.SupportsBitwise) || (type.UnsupportedMethods.Contains(method.MethodName))) { #> + throw new NotSupportedException(); +<# } else if (method.Operator != null) { #> + + Span indices = new Span(new int[result.Rank]); + for(int i = 0; i < <#= method.ResultName #>.Length; i++) + { + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); + <#=method.GetElementOperation(type.TypeName, "[indices]")#>; + } + +<# } else if (method.MethodName == "Contract") {#> + var leftIndices = new int[left.Rank]; + var rightIndices = new int[right.Rank]; + var resultIndices = new int[result.Rank]; + + var summingDimensions = new int[leftAxes.Length]; + for(int i = 0; i < leftAxes.Length; i++) + { + summingDimensions[i] = left.dimensions[leftAxes[i]]; + } + + var summingStrides = ArrayUtilities.GetStrides(summingDimensions); + int summingLength = (int)ArrayUtilities.GetProduct(summingDimensions); + + var resultStrides = result.strides; + + // translates from result index to left non-summing dimensions' index portion + // since left non-summing dimensions are given precedence in result, the end is zero-padded + int[] leftNonSummingStrides = new int[result.Rank]; + + // translates from summing index to left summing dimensions' index portion + int[] leftSummingStrides = new int[leftAxes.Length]; + ArrayUtilities.SplitStrides(left.strides, leftAxes, leftNonSummingStrides, 0, leftSummingStrides, 0); + + // translates from result index to right non-summing dimensions' index portion + int[] rightNonSummingStrides = new int[result.Rank]; + // right non-summing dimensions appear after left non-summing dimensions. 
+ int rightNonSummingStridesOffset = (left.Rank - leftAxes.Length); + + // translates from summing index to right summing dimensions' index portion + int[] rightSummingStrides = new int[rightAxes.Length]; + ArrayUtilities.SplitStrides(right.strides, rightAxes, rightNonSummingStrides, rightNonSummingStridesOffset, rightSummingStrides, 0); + + for (int resultIndex = 0; resultIndex < result.Length; resultIndex++) + { + <#=type.TypeName#> sum = (<#=type.TypeName#>)0; + + ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, resultIndex, resultIndices); + + int leftIndexNonSumming = ArrayUtilities.TransformIndexByStrides(resultIndex, resultStrides, result.IsReversedStride, leftNonSummingStrides); + int rightIndexNonSumming = ArrayUtilities.TransformIndexByStrides(resultIndex, resultStrides, result.IsReversedStride, rightNonSummingStrides); + + for (int summingIndex = 0; summingIndex < summingLength; summingIndex++) + { + int leftIndexSumming = ArrayUtilities.TransformIndexByStrides(summingIndex, summingStrides, false, leftSummingStrides); + int rightIndexSumming = ArrayUtilities.TransformIndexByStrides(summingIndex, summingStrides, false, rightSummingStrides); + + int leftIndex = leftIndexNonSumming + leftIndexSumming; + int rightIndex = rightIndexNonSumming + rightIndexSumming; + + // todo, make this more efficient + ArrayUtilities.GetIndices(left.strides, left.IsReversedStride, leftIndex, leftIndices); + ArrayUtilities.GetIndices(right.strides, right.IsReversedStride, rightIndex, rightIndices); + + sum += (<#=type.TypeName#>)(left[leftIndices] * right[rightIndices]); + } + + result[resultIndices] = sum; + } +<# } #> + } +<# } #> + +<# foreach (MethodConfiguration method in methodConfiguration) { #> + public <#= method.GetResultMethodSignature("DenseTensor", type.TypeName)#> + { +<# if ((method.IsNumeric && !type.SupportsNumeric) || (method.IsBitwise && !type.SupportsBitwise) || (type.UnsupportedMethods.Contains(method.MethodName))) { #> + throw new NotSupportedException(); +<# } else if (method.Operator != null) { #> + +<# if (method.MethodType == MethodType.UnaryInPlace) { #> + var <#=method.ResultName #>Span = <#=method.ResultName #>.Buffer.Span; + var <#=method.Op1Name #>Span = <#=method.Op1Name #>.Buffer.Span; + for(int i = 0; i < <#=method.ResultName #>Span.Length; i++) + { + <#=method.GetElementOperation(type.TypeName, "Span[i]")#>; + } +<# } else {#> + var <#=method.ResultName #>Span = <#=method.ResultName #>.Buffer.Span; + var <#=method.Op1Name #>Span = <#=method.Op1Name #>.Buffer.Span; +<# if ((method.MethodType == MethodType.Binary) || (method.MethodType == MethodType.Comparison)) {#> + var <#=method.Op2Name #>Span = <#=method.Op2Name #>.Buffer.Span; +<# } #> + if <#= method.GetLinearOperationCheck() #> + { + for(int i = 0; i < <#= method.ResultName #>Span.Length; i++) + { + <#=method.GetElementOperation(type.TypeName, "Span[i]")#>; + } + } + else + { + int rowMajorIndex = 0; + int colMajorIndex = 0; + + ref int resultIndex = ref <#= method.ResultName #>.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + ref int op1Index = ref <#= method.Op1Name #>.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + +<# if ((method.MethodType == MethodType.Binary) || (method.MethodType == MethodType.Comparison)) {#> + ref int op2Index = ref <#= method.Op2Name #>.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; + + var rowMajorStrides = !<#= method.ResultName #>.IsReversedStride ? <#= method.ResultName #>.strides : + !<#= method.Op1Name #>.IsReversedStride ? 
<#= method.Op1Name #>.strides : + <#= method.Op2Name #>.strides; + var columnMajorStrides = <#= method.ResultName #>.IsReversedStride ? <#= method.ResultName #>.strides : + <#= method.Op1Name #>.IsReversedStride ? <#= method.Op1Name #>.strides : + <#= method.Op2Name #>.strides; +<# } else {#> + var rowMajorStrides = !<#= method.ResultName #>.IsReversedStride ? <#= method.ResultName #>.strides : + <#= method.Op1Name #>.strides; + var columnMajorStrides = <#= method.ResultName #>.IsReversedStride ? <#= method.ResultName #>.strides : + <#= method.Op1Name #>.strides; +<# } #> + for(;rowMajorIndex < <#= method.ResultName #>Span.Length; rowMajorIndex++) + { + colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); + + <#=method.GetElementOperation(type.TypeName, "Span[resultIndex]", "Span[op1Index]", "Span[op2Index]")#>; + + } + } +<# } #> +<# } else if (method.MethodName == "Contract") {#> + var summingDimensions = new int[leftAxes.Length]; + for(int i = 0; i < leftAxes.Length; i++) + { + summingDimensions[i] = left.dimensions[leftAxes[i]]; + } + + var summingStrides = ArrayUtilities.GetStrides(summingDimensions); + int summingLength = (int)ArrayUtilities.GetProduct(summingDimensions); + + var resultStrides = result.strides; + + // translates from result index to left non-summing dimensions' index portion + // since left non-summing dimensions are given precedence in result, the end is zero-padded + int[] leftNonSummingStrides = new int[result.Rank]; + + // translates from summing index to left summing dimensions' index portion + int[] leftSummingStrides = new int[leftAxes.Length]; + ArrayUtilities.SplitStrides(left.strides, leftAxes, leftNonSummingStrides, 0, leftSummingStrides, 0); + + // translates from result index to right non-summing dimensions' index portion + int[] rightNonSummingStrides = new int[result.Rank]; + // right non-summing dimensions appear after left non-summing dimensions. 
+            int rightNonSummingStridesOffset = (left.Rank - leftAxes.Length);
+
+            // translates from summing index to right summing dimensions' index portion
+            int[] rightSummingStrides = new int[rightAxes.Length];
+            ArrayUtilities.SplitStrides(right.strides, rightAxes, rightNonSummingStrides, rightNonSummingStridesOffset, rightSummingStrides, 0);
+
+            var resultSpan = result.Buffer.Span;
+            var leftSpan = left.Buffer.Span;
+            var rightSpan = right.Buffer.Span;
+
+            for (int resultIndex = 0; resultIndex < resultSpan.Length; resultIndex++)
+            {
+                <#=type.TypeName#> sum = (<#=type.TypeName#>)0;
+
+                int leftIndexNonSumming = ArrayUtilities.TransformIndexByStrides(resultIndex, resultStrides, result.IsReversedStride, leftNonSummingStrides);
+                int rightIndexNonSumming = ArrayUtilities.TransformIndexByStrides(resultIndex, resultStrides, result.IsReversedStride, rightNonSummingStrides);
+
+                for (int summingIndex = 0; summingIndex < summingLength; summingIndex++)
+                {
+                    int leftIndexSumming = ArrayUtilities.TransformIndexByStrides(summingIndex, summingStrides, false, leftSummingStrides);
+                    int rightIndexSumming = ArrayUtilities.TransformIndexByStrides(summingIndex, summingStrides, false, rightSummingStrides);
+
+                    int leftIndex = leftIndexNonSumming + leftIndexSumming;
+                    int rightIndex = rightIndexNonSumming + rightIndexSumming;
+
+                    sum += (<#=type.TypeName#>)(leftSpan[leftIndex] * rightSpan[rightIndex]);
+                }
+
+                resultSpan[resultIndex] = sum;
+            }
+<# } #>
+        }
+<# } #>
+    }
+<# } #>
+}
diff --git a/csharp/test/Microsoft.ML.OnnxRuntime.Tests/Tensors/TensorExtensions.cs b/csharp/test/Microsoft.ML.OnnxRuntime.Tests/Tensors/TensorExtensions.cs
new file mode 100644
index 0000000000000..ee9c2438428c0
--- /dev/null
+++ b/csharp/test/Microsoft.ML.OnnxRuntime.Tests/Tensors/TensorExtensions.cs
@@ -0,0 +1,42 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+// This file is copied and adapted from the following git repository -
+// https://github.com/dotnet/corefx
+// Commit ID: bdd0814360d4c3a58860919f292a306242f27da1
+// Path: /src/System.Numerics.Tensors/tests/TensorExtensions.cs
+// Original license statement below -
+
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
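
Contract pairs leftAxes[i] with rightAxes[i], splits each operand's strides into non-summing parts (addressed by the result index) and summing parts (addressed by a flattened summing index), and accumulates products. For two matrices contracted over left axis 1 and right axis 0 this reduces to the ordinary matrix product; the plain-array sketch below (values chosen arbitrarily, not the Tensor API) shows the semantics the strided implementation above reproduces.

using System;

static class ContractDemo
{
    static void Main()
    {
        // Contracting a 2x3 tensor with a 3x2 tensor over left axis 1 and right axis 0
        // (leftAxes = {1}, rightAxes = {0}) leaves the non-summing axes {2} and {2},
        // so the result is 2x2 -- the ordinary matrix product.
        int[,] left  = { { 1, 2, 3 }, { 4, 5, 6 } };
        int[,] right = { { 7, 8 }, { 9, 10 }, { 11, 12 } };

        var result = new int[2, 2];
        for (int i = 0; i < 2; i++)            // left non-summing axis
        {
            for (int j = 0; j < 2; j++)        // right non-summing axis
            {
                int sum = 0;
                for (int k = 0; k < 3; k++)    // the paired (summing) axis
                {
                    sum += left[i, k] * right[k, j];
                }
                result[i, j] = sum;
            }
        }

        Console.WriteLine(result[0, 0]);       // 1*7 + 2*9 + 3*11 = 58
        Console.WriteLine(result[1, 1]);       // 4*8 + 5*10 + 6*12 = 154
    }
}
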
+using System;
+
+namespace Microsoft.ML.OnnxRuntime.Tensors
+{
+    public static partial class TensorExtensions
+    {
+        private static int[] s_zeroArray = new[] { 0 };
+        private static int[] s_oneArray = new[] { 1 };
+
+        internal static Tensor<T> MatrixMultiply<T>(this Tensor<T> left, Tensor<T> right)
+        {
+            if (left.Rank != 2)
+            {
+                throw new InvalidOperationException($"{nameof(MatrixMultiply)} is only valid for a {nameof(Tensor<T>)} of {nameof(left.Rank)} 2.");
+            }
+
+            if (right.Rank != 2)
+            {
+                throw new ArgumentException($"{nameof(Tensor<T>)} {nameof(right)} must have {nameof(left.Rank)} 2.", nameof(right));
+            }
+
+            if (left.dimensions[1] != right.dimensions[0])
+            {
+                throw new ArgumentException($"{nameof(Tensor<T>)} {nameof(right)} must have first dimension of {left.dimensions[1]}.", nameof(right));
+            }
+
+            return TensorOperations.Contract(left, right, s_oneArray, s_zeroArray);
+        }
+    }
+}
diff --git a/csharp/test/Microsoft.ML.OnnxRuntime.Tests/Tensors/TensorOperations.cs b/csharp/test/Microsoft.ML.OnnxRuntime.Tests/Tensors/TensorOperations.cs
new file mode 100644
index 0000000000000..2efda4872edec
--- /dev/null
+++ b/csharp/test/Microsoft.ML.OnnxRuntime.Tests/Tensors/TensorOperations.cs
@@ -0,0 +1,750 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+// This file is copied and adapted from the following git repository -
+// https://github.com/dotnet/corefx
+// Commit ID: bdd0814360d4c3a58860919f292a306242f27da1
+// Path: /src/System.Numerics.Tensors/tests/TensorOperations.cs
+// Original license statement below -
+
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
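
MatrixMultiply is a thin wrapper over Contract with leftAxes = {1} and rightAxes = {0}. A possible usage sketch follows; it assumes the DenseTensor<T> constructor and params-int indexer carried over from the copied corefx sources, that int is among the generated arithmetic types, and that the caller sits inside this test assembly since the method is internal.

using System;
using Microsoft.ML.OnnxRuntime.Tensors;

static class MatrixMultiplyDemo
{
    static void Main()
    {
        // Same operands as the plain-array contraction sketch earlier, now through the
        // Tensor API. MatrixMultiply forwards to Contract(left, right, {1}, {0}).
        var left = new DenseTensor<int>(new[] { 2, 3 });
        var right = new DenseTensor<int>(new[] { 3, 2 });

        int v = 1;
        for (int i = 0; i < 2; i++)
            for (int j = 0; j < 3; j++)
                left[i, j] = v++;              // fills 1..6

        v = 7;
        for (int i = 0; i < 3; i++)
            for (int j = 0; j < 2; j++)
                right[i, j] = v++;             // fills 7..12

        var product = left.MatrixMultiply(right);   // 2x2 result
        Console.WriteLine(product[0, 0]);           // 58
        Console.WriteLine(product[1, 1]);           // 154
    }
}

The result shape is 2x2, matching the non-summing dimensions of the two operands.
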
+ +using System; + +namespace Microsoft.ML.OnnxRuntime.Tensors +{ + public static partial class TensorOperations + { + internal static void ValidateBinaryArgs(Tensor left, Tensor right) + { + if (left.Rank != right.Rank || left.Length != right.Length) + { + throw new ArgumentException("Operands must have matching dimensions", nameof(right)); + } + + if (left.Rank == 0) + { + throw new ArgumentException($"Cannot operate on Tensor with {nameof(Tensor.Rank)} of 0.", nameof(left)); + } + + for (int i = 0; i < left.Rank; i++) + { + if (left.dimensions[i] != right.dimensions[i]) + { + throw new ArgumentException("Operands must have matching dimensions", nameof(right)); + } + } + } + + internal static void ValidateBinaryArgs(Tensor left, Tensor right, Tensor result) + { + if (left.Rank != right.Rank || left.Length != right.Length) + { + throw new ArgumentException("Operands must have matching dimensions", nameof(right)); + } + + if (left.Rank != result.Rank || left.Length != result.Length) + { + throw new ArgumentException("Operands must have matching dimensions", nameof(result)); + } + + if (left.Rank == 0) + { + throw new ArgumentException($"Cannot operate on Tensor with {nameof(Tensor.Rank)} of 0.", nameof(left)); + } + + for (int i = 0; i < result.Rank; i++) + { + if (left.dimensions[i] != right.dimensions[i]) + { + throw new ArgumentException("Operands must have matching dimensions", nameof(right)); + } + + if (left.dimensions[i] != result.dimensions[i]) + { + throw new ArgumentException("Operands and result must have matching dimensions", nameof(result)); + } + } + } + + internal static void ValidateBinaryArgs(Tensor left, Tensor right, Tensor result) + { + if (left.Rank != right.Rank || left.Length != right.Length) + { + throw new ArgumentException("Operands must have matching dimensions", nameof(right)); + } + + if (left.Rank != result.Rank || left.Length != result.Length) + { + throw new ArgumentException("Operands must have matching dimensions", nameof(result)); + } + + if (left.Rank == 0) + { + throw new ArgumentException($"Cannot operate on Tensor with {nameof(Tensor.Rank)} of 0.", nameof(left)); + } + + for (int i = 0; i < result.Rank; i++) + { + if (left.dimensions[i] != right.dimensions[i]) + { + throw new ArgumentException("Operands must have matching dimensions", nameof(right)); + } + + if (left.dimensions[i] != result.dimensions[i]) + { + throw new ArgumentException("Operands and result must have matching dimensions", nameof(result)); + } + } + } + + internal static void ValidateArgs(Tensor tensor) + { + if (tensor.Rank == 0) + { + throw new ArgumentException($"Cannot operate on Tensor with {nameof(Tensor.Rank)} of 0.", nameof(tensor)); + } + } + + internal static void ValidateArgs(Tensor tensor, Tensor result) + { + if (tensor.Rank != result.Rank || tensor.Length != result.Length) + { + throw new ArgumentException("Operands and result must have matching dimensions", nameof(result)); + } + + if (tensor.Rank == 0) + { + throw new ArgumentException($"Cannot operate on Tensor with {nameof(Tensor.Rank)} of 0.", nameof(tensor)); + } + + for (int i = 0; i < result.Rank; i++) + { + if (tensor.dimensions[i] != result.dimensions[i]) + { + throw new ArgumentException("Operands and result must have matching dimensions", nameof(result)); + } + } + } + + internal static int[] ValidateContractArgs(Tensor left, Tensor right, int[] leftAxes, int[] rightAxes) + { + if (leftAxes == null) + { + throw new ArgumentNullException(nameof(left)); + } + + if (rightAxes == null) + { + throw new 
ArgumentNullException(nameof(left)); + } + + if (leftAxes.Length != rightAxes.Length) + { + throw new ArgumentException($"{nameof(leftAxes)} and {nameof(rightAxes)} must have the same length, but were {leftAxes.Length} and {rightAxes.Length}, respectively."); + } + + for (int i = 0; i < leftAxes.Length; i++) + { + var leftAxis = leftAxes[i]; + + if (leftAxis >= left.Rank) + { + throw new ArgumentOutOfRangeException($"{nameof(leftAxes)}[{i}] was set to axis index {leftAxis} which exceeds the Rank of {left}."); + } + + var leftDimension = left.dimensions[leftAxis]; + + var rightAxis = rightAxes[i]; + + if (rightAxis >= right.Rank) + { + throw new ArgumentOutOfRangeException($"{nameof(rightAxes)}[{i}] was set to axis index {rightAxis} which exceeds the Rank of {right}."); + } + + var rightDimension = right.dimensions[rightAxis]; + + if (leftDimension != rightDimension) + { + throw new ArgumentOutOfRangeException($"Tensors may only be contracted on axes of the same length, but {nameof(leftAxes)} index {i} was length {leftDimension} and {nameof(rightAxes)} index {i} was length {rightDimension}."); + } + } + + var leftNonSummingDimensions = left.Rank - leftAxes.Length; + var rightNonSummingDimensions = right.Rank - rightAxes.Length; + var resultDimensions = new int[leftNonSummingDimensions + rightNonSummingDimensions]; + int dimensionsIndex = 0; + + Action, int[]> fillDimensions = (tensor, axes) => + { + for (int i = 0; i < tensor.Rank; i++) + { + var skip = false; + foreach (var contractionIndex in axes) + { + if (contractionIndex == i) + { + skip = true; + break; + } + } + + if (!skip) + { + resultDimensions[dimensionsIndex++] = tensor.dimensions[i]; + } + } + }; + + fillDimensions(left, leftAxes); + fillDimensions(right, rightAxes); + + return resultDimensions; + } + + internal static int[] ValidateContractArgs(Tensor left, Tensor right, int[] leftAxes, int[] rightAxes, Tensor result) + { + var expectedDimensions = ValidateContractArgs(left, right, leftAxes, rightAxes); + + if (result.Rank != expectedDimensions.Length) + { + throw new ArgumentException($"{nameof(result)} should have {expectedDimensions.Length} dimensions but had {result.Rank}."); + } + + for (int i = 0; i < expectedDimensions.Length; i++) + { + if (result.dimensions[i] != expectedDimensions[i]) + { + throw new ArgumentException($"{nameof(result)} dimension {i} should be {expectedDimensions[i]} but was {result.dimensions[i]}."); + } + } + + return expectedDimensions; + } + + internal static void Add(Tensor left, Tensor right, Tensor result) + { + ValidateBinaryArgs(left, right, result); + + TensorArithmetic.Instance.Add(left, right, result); + } + + internal static Tensor Add(Tensor left, Tensor right) + { + ValidateBinaryArgs(left, right); + + var result = left.CloneEmpty(); + + TensorArithmetic.Instance.Add(left, right, result); + + return result; + } + + internal static void Add(Tensor tensor, T scalar, Tensor result) + { + ValidateArgs(tensor, result); + + TensorArithmetic.Instance.Add(tensor, scalar, result); + } + + internal static Tensor Add(Tensor tensor, T scalar) + { + ValidateArgs(tensor); + + var result = tensor.CloneEmpty(); + + TensorArithmetic.Instance.Add(tensor, scalar, result); + + return result; + } + + internal static void And(Tensor left, Tensor right, Tensor result) + { + ValidateBinaryArgs(left, right, result); + + TensorArithmetic.Instance.And(left, right, result); + } + + internal static Tensor And(Tensor left, Tensor right) + { + ValidateBinaryArgs(left, right); + + var result = left.CloneEmpty(); + + 
TensorArithmetic.Instance.And(left, right, result); + + return result; + } + + internal static void And(Tensor tensor, T scalar, Tensor result) + { + ValidateArgs(tensor, result); + + TensorArithmetic.Instance.And(tensor, scalar, result); + } + + internal static Tensor And(Tensor tensor, T scalar) + { + ValidateArgs(tensor); + + var result = tensor.CloneEmpty(); + + TensorArithmetic.Instance.And(tensor, scalar, result); + + return result; + } + + internal static void Contract(Tensor left, Tensor right, int[] leftAxes, int[] rightAxes, Tensor result) + { + var resultDimensions = ValidateContractArgs(left, right, leftAxes, rightAxes, result); + + TensorArithmetic.Instance.Contract(left, right, leftAxes, rightAxes, result); + } + + internal static Tensor Contract(Tensor left, Tensor right, int[] leftAxes, int[] rightAxes) + { + var resultDimensions = ValidateContractArgs(left, right, leftAxes, rightAxes); + + var result = left.CloneEmpty(resultDimensions); + + TensorArithmetic.Instance.Contract(left, right, leftAxes, rightAxes, result); + + return result; + } + + internal static void Decrement(Tensor tensor, Tensor result) + { + ValidateArgs(tensor, result); + + TensorArithmetic.Instance.Decrement(tensor, result); + } + + internal static Tensor Decrement(Tensor tensor) + { + ValidateArgs(tensor); + + var result = tensor.Clone(); + + TensorArithmetic.Instance.Decrement(tensor, result); + + return result; + } + + internal static void Divide(Tensor left, Tensor right, Tensor result) + { + ValidateBinaryArgs(left, right, result); + + TensorArithmetic.Instance.Divide(left, right, result); + } + + internal static Tensor Divide(Tensor left, Tensor right) + { + ValidateBinaryArgs(left, right); + + var result = left.CloneEmpty(); + + TensorArithmetic.Instance.Divide(left, right, result); + + return result; + } + + internal static void Divide(Tensor tensor, T scalar, Tensor result) + { + ValidateArgs(tensor, result); + + TensorArithmetic.Instance.Divide(tensor, scalar, result); + } + + internal static Tensor Divide(Tensor tensor, T scalar) + { + ValidateArgs(tensor); + + var result = tensor.CloneEmpty(); + + TensorArithmetic.Instance.Divide(tensor, scalar, result); + + return result; + } + + internal static void Equals(Tensor left, Tensor right, Tensor result) + { + ValidateBinaryArgs(left, right, result); + + TensorArithmetic.Instance.Equals(left, right, result); + } + + internal static Tensor Equals(Tensor left, Tensor right) + { + ValidateBinaryArgs(left, right); + + var result = left.CloneEmpty(); + + TensorArithmetic.Instance.Equals(left, right, result); + + return result; + } + + internal static void GreaterThan(Tensor left, Tensor right, Tensor result) + { + ValidateBinaryArgs(left, right, result); + + TensorArithmetic.Instance.GreaterThan(left, right, result); + } + + internal static Tensor GreaterThan(Tensor left, Tensor right) + { + ValidateBinaryArgs(left, right); + + var result = left.CloneEmpty(); + + TensorArithmetic.Instance.GreaterThan(left, right, result); + + return result; + } + + internal static void GreaterThanOrEqual(Tensor left, Tensor right, Tensor result) + { + ValidateBinaryArgs(left, right, result); + + TensorArithmetic.Instance.GreaterThanOrEqual(left, right, result); + } + + internal static Tensor GreaterThanOrEqual(Tensor left, Tensor right) + { + ValidateBinaryArgs(left, right); + + var result = left.CloneEmpty(); + + TensorArithmetic.Instance.GreaterThanOrEqual(left, right, result); + + return result; + } + + internal static void Increment(Tensor tensor, Tensor result) + 
{ + ValidateArgs(tensor, result); + + TensorArithmetic.Instance.Increment(tensor, result); + } + + internal static Tensor Increment(Tensor tensor) + { + ValidateArgs(tensor); + + var result = tensor.Clone(); + + TensorArithmetic.Instance.Increment(tensor, result); + + return result; + } + + internal static void LeftShift(Tensor tensor, int value, Tensor result) + { + ValidateArgs(tensor, result); + + TensorArithmetic.Instance.LeftShift(tensor, value, result); + } + + internal static Tensor LeftShift(Tensor tensor, int value) + { + ValidateArgs(tensor); + + var result = tensor.CloneEmpty(); + + TensorArithmetic.Instance.LeftShift(tensor, value, result); + + return result; + } + + internal static void LessThan(Tensor left, Tensor right, Tensor result) + { + ValidateBinaryArgs(left, right, result); + + TensorArithmetic.Instance.LessThan(left, right, result); + } + + internal static Tensor LessThan(Tensor left, Tensor right) + { + ValidateBinaryArgs(left, right); + + var result = left.CloneEmpty(); + + TensorArithmetic.Instance.LessThan(left, right, result); + + return result; + } + + internal static void LessThanOrEqual(Tensor left, Tensor right, Tensor result) + { + ValidateBinaryArgs(left, right, result); + + TensorArithmetic.Instance.LessThanOrEqual(left, right, result); + } + + internal static Tensor LessThanOrEqual(Tensor left, Tensor right) + { + ValidateBinaryArgs(left, right); + + var result = left.CloneEmpty(); + + TensorArithmetic.Instance.LessThanOrEqual(left, right, result); + + return result; + } + + internal static void Modulo(Tensor left, Tensor right, Tensor result) + { + ValidateBinaryArgs(left, right, result); + + TensorArithmetic.Instance.Modulo(left, right, result); + } + + internal static Tensor Modulo(Tensor left, Tensor right) + { + ValidateBinaryArgs(left, right); + + var result = left.CloneEmpty(); + + TensorArithmetic.Instance.Modulo(left, right, result); + + return result; + } + + internal static void Modulo(Tensor tensor, T scalar, Tensor result) + { + ValidateArgs(tensor, result); + + TensorArithmetic.Instance.Modulo(tensor, scalar, result); + } + + internal static Tensor Modulo(Tensor tensor, T scalar) + { + ValidateArgs(tensor); + + var result = tensor.CloneEmpty(); + + TensorArithmetic.Instance.Modulo(tensor, scalar, result); + + return result; + } + + internal static void Multiply(Tensor left, Tensor right, Tensor result) + { + ValidateBinaryArgs(left, right, result); + + TensorArithmetic.Instance.Multiply(left, right, result); + } + + internal static Tensor Multiply(Tensor left, Tensor right) + { + ValidateBinaryArgs(left, right); + + var result = left.CloneEmpty(); + + TensorArithmetic.Instance.Multiply(left, right, result); + + return result; + } + + internal static void Multiply(Tensor tensor, T scalar, Tensor result) + { + ValidateArgs(tensor, result); + + TensorArithmetic.Instance.Multiply(tensor, scalar, result); + } + + internal static Tensor Multiply(Tensor tensor, T scalar) + { + ValidateArgs(tensor); + + var result = tensor.CloneEmpty(); + + TensorArithmetic.Instance.Multiply(tensor, scalar, result); + + return result; + } + + internal static void NotEquals(Tensor left, Tensor right, Tensor result) + { + ValidateBinaryArgs(left, right, result); + + TensorArithmetic.Instance.NotEquals(left, right, result); + } + + internal static Tensor NotEquals(Tensor left, Tensor right) + { + ValidateBinaryArgs(left, right); + + var result = left.CloneEmpty(); + + TensorArithmetic.Instance.NotEquals(left, right, result); + + return result; + } + + internal static 
void Or(Tensor left, Tensor right, Tensor result) + { + ValidateBinaryArgs(left, right, result); + + TensorArithmetic.Instance.Or(left, right, result); + } + + internal static Tensor Or(Tensor left, Tensor right) + { + ValidateBinaryArgs(left, right); + + var result = left.CloneEmpty(); + + TensorArithmetic.Instance.Or(left, right, result); + + return result; + } + + internal static void Or(Tensor tensor, T scalar, Tensor result) + { + ValidateArgs(tensor, result); + + TensorArithmetic.Instance.Or(tensor, scalar, result); + } + + internal static Tensor Or(Tensor tensor, T scalar) + { + ValidateArgs(tensor); + + var result = tensor.CloneEmpty(); + + TensorArithmetic.Instance.Or(tensor, scalar, result); + + return result; + } + + internal static void RightShift(Tensor tensor, int value, Tensor result) + { + ValidateArgs(tensor, result); + + TensorArithmetic.Instance.RightShift(tensor, value, result); + } + + internal static Tensor RightShift(Tensor tensor, int value) + { + ValidateArgs(tensor); + + var result = tensor.CloneEmpty(); + + TensorArithmetic.Instance.RightShift(tensor, value, result); + + return result; + } + + internal static void Subtract(Tensor left, Tensor right, Tensor result) + { + ValidateBinaryArgs(left, right, result); + + TensorArithmetic.Instance.Subtract(left, right, result); + } + + internal static Tensor Subtract(Tensor left, Tensor right) + { + ValidateBinaryArgs(left, right); + + var result = left.CloneEmpty(); + + TensorArithmetic.Instance.Subtract(left, right, result); + + return result; + } + + internal static void Subtract(Tensor tensor, T scalar, Tensor result) + { + ValidateArgs(tensor, result); + + TensorArithmetic.Instance.Subtract(tensor, scalar, result); + } + + internal static Tensor Subtract(Tensor tensor, T scalar) + { + ValidateArgs(tensor); + + var result = tensor.CloneEmpty(); + + TensorArithmetic.Instance.Subtract(tensor, scalar, result); + + return result; + } + + internal static void UnaryMinus(Tensor tensor, Tensor result) + { + ValidateArgs(tensor, result); + + TensorArithmetic.Instance.UnaryMinus(tensor, result); + } + + internal static Tensor UnaryMinus(Tensor tensor) + { + ValidateArgs(tensor); + + var result = tensor.CloneEmpty(); + + TensorArithmetic.Instance.UnaryMinus(tensor, result); + + return result; + } + + internal static void UnaryPlus(Tensor tensor, Tensor result) + { + ValidateArgs(tensor, result); + + TensorArithmetic.Instance.UnaryPlus(tensor, result); + } + + internal static Tensor UnaryPlus(Tensor tensor) + { + ValidateArgs(tensor); + + var result = tensor.CloneEmpty(); + + TensorArithmetic.Instance.UnaryPlus(tensor, result); + + return result; + } + + internal static void Xor(Tensor left, Tensor right, Tensor result) + { + ValidateBinaryArgs(left, right, result); + + TensorArithmetic.Instance.Xor(left, right, result); + } + + internal static Tensor Xor(Tensor left, Tensor right) + { + ValidateBinaryArgs(left, right); + + var result = left.CloneEmpty(); + + TensorArithmetic.Instance.Xor(left, right, result); + + return result; + } + + internal static void Xor(Tensor tensor, T scalar, Tensor result) + { + ValidateArgs(tensor, result); + + TensorArithmetic.Instance.Xor(tensor, scalar, result); + } + + internal static Tensor Xor(Tensor tensor, T scalar) + { + ValidateArgs(tensor); + + var result = tensor.CloneEmpty(); + + TensorArithmetic.Instance.Xor(tensor, scalar, result); + + return result; + } + + } +} diff --git a/csharp/test/Microsoft.ML.OnnxRuntime.Tests/Tensors/TensorOperations.tt 
b/csharp/test/Microsoft.ML.OnnxRuntime.Tests/Tensors/TensorOperations.tt new file mode 100644 index 0000000000000..627aa7625cbd2 --- /dev/null +++ b/csharp/test/Microsoft.ML.OnnxRuntime.Tests/Tensors/TensorOperations.tt @@ -0,0 +1,251 @@ +<#@ template debug="false" hostspecific="false" language="C#" #> +<#@ assembly name="System.Core" #> +<#@ output extension=".cs" #> +<#@ include file="TensorTemplate.ttinclude" #>// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// This file is copied and adapted from the following git repository - +// https://github.com/dotnet/corefx +// Commit ID: bdd0814360d4c3a58860919f292a306242f27da1 +// Path: /src/System.Numerics.Tensors/tests/TensorOperations.cs +// Original license statement below - + +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; + +namespace Microsoft.ML.OnnxRuntime.Tensors +{ + public static partial class TensorOperations + { + internal static void ValidateBinaryArgs(Tensor left, Tensor right) + { + if (left.Rank != right.Rank || left.Length != right.Length) + { + throw new ArgumentException("Operands must have matching dimensions", nameof(right)); + } + + if (left.Rank == 0) + { + throw new ArgumentException($"Cannot operate on Tensor with {nameof(Tensor.Rank)} of 0.", nameof(left)); + } + + for (int i = 0; i < left.Rank; i++) + { + if (left.dimensions[i] != right.dimensions[i]) + { + throw new ArgumentException("Operands must have matching dimensions", nameof(right)); + } + } + } + + internal static void ValidateBinaryArgs(Tensor left, Tensor right, Tensor result) + { + if (left.Rank != right.Rank || left.Length != right.Length) + { + throw new ArgumentException("Operands must have matching dimensions", nameof(right)); + } + + if (left.Rank != result.Rank || left.Length != result.Length) + { + throw new ArgumentException("Operands must have matching dimensions", nameof(result)); + } + + if (left.Rank == 0) + { + throw new ArgumentException($"Cannot operate on Tensor with {nameof(Tensor.Rank)} of 0.", nameof(left)); + } + + for (int i = 0; i < result.Rank; i++) + { + if (left.dimensions[i] != right.dimensions[i]) + { + throw new ArgumentException("Operands must have matching dimensions", nameof(right)); + } + + if (left.dimensions[i] != result.dimensions[i]) + { + throw new ArgumentException("Operands and result must have matching dimensions", nameof(result)); + } + } + } + + internal static void ValidateBinaryArgs(Tensor left, Tensor right, Tensor result) + { + if (left.Rank != right.Rank || left.Length != right.Length) + { + throw new ArgumentException("Operands must have matching dimensions", nameof(right)); + } + + if (left.Rank != result.Rank || left.Length != result.Length) + { + throw new ArgumentException("Operands must have matching dimensions", nameof(result)); + } + + if (left.Rank == 0) + { + throw new ArgumentException($"Cannot operate on Tensor with {nameof(Tensor.Rank)} of 0.", nameof(left)); + } + + for (int i = 0; i < result.Rank; i++) + { + if (left.dimensions[i] != right.dimensions[i]) + { + throw new ArgumentException("Operands must have matching dimensions", nameof(right)); + } + + if (left.dimensions[i] != result.dimensions[i]) + { + throw new ArgumentException("Operands and result must have matching dimensions", nameof(result)); + } + } + } + + internal static void ValidateArgs(Tensor tensor) + 
{ + if (tensor.Rank == 0) + { + throw new ArgumentException($"Cannot operate on Tensor with {nameof(Tensor.Rank)} of 0.", nameof(tensor)); + } + } + + internal static void ValidateArgs(Tensor tensor, Tensor result) + { + if (tensor.Rank != result.Rank || tensor.Length != result.Length) + { + throw new ArgumentException("Operands and result must have matching dimensions", nameof(result)); + } + + if (tensor.Rank == 0) + { + throw new ArgumentException($"Cannot operate on Tensor with {nameof(Tensor.Rank)} of 0.", nameof(tensor)); + } + + for (int i = 0; i < result.Rank; i++) + { + if (tensor.dimensions[i] != result.dimensions[i]) + { + throw new ArgumentException("Operands and result must have matching dimensions", nameof(result)); + } + } + } + + internal static int[] ValidateContractArgs(Tensor left, Tensor right, int[] leftAxes, int[] rightAxes) + { + if (leftAxes == null) + { + throw new ArgumentNullException(nameof(left)); + } + + if (rightAxes == null) + { + throw new ArgumentNullException(nameof(left)); + } + + if (leftAxes.Length != rightAxes.Length) + { + throw new ArgumentException($"{nameof(leftAxes)} and {nameof(rightAxes)} must have the same length, but were {leftAxes.Length} and {rightAxes.Length}, respectively."); + } + + for (int i = 0; i < leftAxes.Length; i++) + { + var leftAxis = leftAxes[i]; + + if (leftAxis >= left.Rank) + { + throw new ArgumentOutOfRangeException($"{nameof(leftAxes)}[{i}] was set to axis index {leftAxis} which exceeds the Rank of {left}."); + } + + var leftDimension = left.dimensions[leftAxis]; + + var rightAxis = rightAxes[i]; + + if (rightAxis >= right.Rank) + { + throw new ArgumentOutOfRangeException($"{nameof(rightAxes)}[{i}] was set to axis index {rightAxis} which exceeds the Rank of {right}."); + } + + var rightDimension = right.dimensions[rightAxis]; + + if (leftDimension != rightDimension) + { + throw new ArgumentOutOfRangeException($"Tensors may only be contracted on axes of the same length, but {nameof(leftAxes)} index {i} was length {leftDimension} and {nameof(rightAxes)} index {i} was length {rightDimension}."); + } + } + + var leftNonSummingDimensions = left.Rank - leftAxes.Length; + var rightNonSummingDimensions = right.Rank - rightAxes.Length; + var resultDimensions = new int[leftNonSummingDimensions + rightNonSummingDimensions]; + int dimensionsIndex = 0; + + Action, int[]> fillDimensions = (tensor, axes) => + { + for (int i = 0; i < tensor.Rank; i++) + { + var skip = false; + foreach (var contractionIndex in axes) + { + if (contractionIndex == i) + { + skip = true; + break; + } + } + + if (!skip) + { + resultDimensions[dimensionsIndex++] = tensor.dimensions[i]; + } + } + }; + + fillDimensions(left, leftAxes); + fillDimensions(right, rightAxes); + + return resultDimensions; + } + + internal static int[] ValidateContractArgs(Tensor left, Tensor right, int[] leftAxes, int[] rightAxes, Tensor result) + { + var expectedDimensions = ValidateContractArgs(left, right, leftAxes, rightAxes); + + if (result.Rank != expectedDimensions.Length) + { + throw new ArgumentException($"{nameof(result)} should have {expectedDimensions.Length} dimensions but had {result.Rank}."); + } + + for (int i = 0; i < expectedDimensions.Length; i++) + { + if (result.dimensions[i] != expectedDimensions[i]) + { + throw new ArgumentException($"{nameof(result)} dimension {i} should be {expectedDimensions[i]} but was {result.dimensions[i]}."); + } + } + + return expectedDimensions; + } + +<# foreach (MethodConfiguration method in methodConfiguration) { #> + internal static 
<#= method.GetGenericResultMethodSignature("Tensor", "T")#> + { + <#= method.GetValidationMethod(true) #> + + TensorArithmetic.Instance.<#=method.MethodName#>(<#=method.GetCallArguments()#>, <#= method.ResultName #>); + } + + internal static <#= method.GetGenericMethodSignature("Tensor", "T")#> + { + <#= method.GetValidationMethod(false) #> + + var <#= method.ResultName #> = <#=method.InitializeResult("T")#>; + + TensorArithmetic.Instance.<#=method.MethodName#>(<#=method.GetCallArguments()#>, <#= method.ResultName #>); + + return <#= method.ResultName #>; + } + +<# } #> + } +} diff --git a/csharp/test/Microsoft.ML.OnnxRuntime.Tests/Tensors/TensorTemplate.ttinclude b/csharp/test/Microsoft.ML.OnnxRuntime.Tests/Tensors/TensorTemplate.ttinclude new file mode 100644 index 0000000000000..9448791a5db6c --- /dev/null +++ b/csharp/test/Microsoft.ML.OnnxRuntime.Tests/Tensors/TensorTemplate.ttinclude @@ -0,0 +1,328 @@ +<#@ import namespace="System.Linq" #> +<#@ import namespace="System.Text" #> +<#@ import namespace="System.Collections.Generic" #> +<#+ + public class TypeConfiguration + { + public TypeConfiguration(string typeName, string classPrefix = null, string oneLiteral = "1", string zeroLiteral = "0", bool supportsNumeric = true, bool supportsBitwise = true, IEnumerable unsupportedMethods = null) + { + TypeName = typeName; + ClassPrefix = classPrefix ?? char.ToUpper(typeName[0]) + typeName.Substring(1); + OneLiteral = oneLiteral; + ZeroLiteral = zeroLiteral; + SupportsNumeric = supportsNumeric; + SupportsBitwise = supportsBitwise; + UnsupportedMethods = new HashSet(unsupportedMethods ?? Enumerable.Empty()); + } + + public string TypeName { get; } + public string ClassPrefix { get; } + public string OneLiteral { get; } + public string ZeroLiteral { get; } + + public bool SupportsNumeric { get; } + public bool SupportsBitwise { get; } + public ISet UnsupportedMethods { get; } + } + + public string GenerateIfStatementHeader(TypeConfiguration type) + { + string keyword = (type == typeConfiguration[0]) ? 
"if" : "else if"; + return $"{keyword} (typeof(T) == typeof({type.TypeName}))"; + } + + public TypeConfiguration[] typeConfiguration = new [] + { + new TypeConfiguration("bool", oneLiteral:"true", zeroLiteral:"false", supportsNumeric: false, unsupportedMethods: new[] {"LeftShift", "RightShift"}), + new TypeConfiguration("byte"), + new TypeConfiguration("char", oneLiteral:"(char)1", zeroLiteral:"(char)0"), + new TypeConfiguration("decimal", supportsBitwise: false), + new TypeConfiguration("double", oneLiteral:"1.0", supportsBitwise: false), + new TypeConfiguration("float", oneLiteral:"1.0f", supportsBitwise: false), + new TypeConfiguration("int"), + new TypeConfiguration("long"), + new TypeConfiguration("sbyte", classPrefix:"SByte"), + new TypeConfiguration("short"), + new TypeConfiguration("uint", classPrefix:"UInt", unsupportedMethods: new[] {"UnaryMinus"}), + new TypeConfiguration("ulong", classPrefix:"ULong", unsupportedMethods: new[] {"UnaryMinus"}), + new TypeConfiguration("ushort", classPrefix:"UShort", unsupportedMethods: new[] {"UnaryMinus"}) + }; + + public enum MethodType + { + Unary, + UnaryInPlace, + BinaryScalar, + BinaryInt, + Binary, + Comparison, + Contraction + } + + public class MethodConfiguration + { + public MethodConfiguration(string methodName, MethodType methodType, string op = null, bool isNumeric = false, bool isBitwise = false) + { + MethodName = methodName; + MethodType = methodType; + Operator = op; + IsNumeric = isNumeric; + IsBitwise = isBitwise; + } + + public string ResultName => "result"; + + public string Op1Name + { + get + { + switch (MethodType) + { + case MethodType.Unary: + case MethodType.UnaryInPlace: + case MethodType.BinaryScalar: + case MethodType.BinaryInt: + return "tensor"; + case MethodType.Binary: + case MethodType.Comparison: + case MethodType.Contraction: + return "left"; + default: + throw new ArgumentException(); + }; + } + } + + public string Op2Name + { + get + { + switch (MethodType) + { + case MethodType.BinaryScalar: + return "scalar"; + case MethodType.BinaryInt: + return "value"; + case MethodType.Binary: + case MethodType.Comparison: + case MethodType.Contraction: + return "right"; + case MethodType.Unary: + case MethodType.UnaryInPlace: + default: + throw new ArgumentException(); + }; + } + } + + public string MethodName { get; } + public MethodType MethodType { get; } + public string Operator { get; } + + public string GetGenericMethodSignature(string tensorType, string genericType) + { + var resultType = GetResultType(tensorType, genericType); + var arguments = GetMethodArguments(tensorType, genericType); + + return $"{resultType} {MethodName}<{genericType}>({arguments})"; + } + + public string GetGenericResultMethodSignature(string tensorType, string genericType) + { + var resultType = GetResultType(tensorType, genericType); + var arguments = GetMethodArguments(tensorType, genericType); + + return $"void {MethodName}<{genericType}>({arguments}, {resultType} {ResultName})"; + } + + public string GetResultMethodSignature(string tensorType, string genericType) + { + var resultType = GetResultType(tensorType, genericType); + var arguments = GetMethodArguments(tensorType, genericType); + + return $"void {MethodName}({arguments}, {resultType} {ResultName})"; + } + + public string GetMethodArguments(string tensorType, string genericType) + { + switch (MethodType) + { + case MethodType.Unary: + case MethodType.UnaryInPlace: + return $"{tensorType}<{genericType}> {Op1Name}"; + case MethodType.BinaryScalar: + return 
$"{tensorType}<{genericType}> {Op1Name}, {genericType} {Op2Name}"; + case MethodType.BinaryInt: + return $"{tensorType}<{genericType}> {Op1Name}, int {Op2Name}"; + case MethodType.Binary: + case MethodType.Comparison: + return $"{tensorType}<{genericType}> {Op1Name}, {tensorType}<{genericType}> {Op2Name}"; + case MethodType.Contraction: + return $"{tensorType}<{genericType}> {Op1Name}, {tensorType}<{genericType}> {Op2Name}, int[] leftAxes, int[] rightAxes"; + default: + throw new ArgumentException(); + } + } + + public string GetCallArguments() + { + switch (MethodType) + { + case MethodType.Unary: + case MethodType.UnaryInPlace: + return $"{Op1Name}"; + case MethodType.BinaryScalar: + case MethodType.BinaryInt: + case MethodType.Binary: + case MethodType.Comparison: + return $"{Op1Name}, {Op2Name}"; + case MethodType.Contraction: + return "left, right, leftAxes, rightAxes"; + default: + throw new ArgumentException(); + } + } + + public string GetValidationMethod(bool includeResult) + { + var suffix = includeResult ? ", result" : ""; + switch (MethodType) + { + case MethodType.Unary: + case MethodType.UnaryInPlace: + case MethodType.BinaryScalar: + case MethodType.BinaryInt: + return $"ValidateArgs({Op1Name}{suffix});"; + case MethodType.Binary: + case MethodType.Comparison: + return $"ValidateBinaryArgs({Op1Name}, {Op2Name}{suffix});"; + case MethodType.Contraction: + return $"var resultDimensions = ValidateContractArgs({Op1Name}, {Op2Name}, leftAxes, rightAxes{suffix});"; + default: + throw new ArgumentException(); + } + } + + public string GetResultType(string tensorType, string typeName) + { + switch (MethodType) + { + case MethodType.Unary: + case MethodType.UnaryInPlace: + case MethodType.BinaryScalar: + case MethodType.BinaryInt: + case MethodType.Binary: + case MethodType.Contraction: + return $"{tensorType}<{typeName}>"; + case MethodType.Comparison: + return $"{tensorType}"; + default: + throw new ArgumentException(); + } + } + + public string GetLinearOperationCheck() + { + switch (MethodType) + { + case MethodType.Unary: + case MethodType.BinaryScalar: + case MethodType.BinaryInt: + return $"({ResultName}.IsReversedStride == {Op1Name}.IsReversedStride)"; + case MethodType.Binary: + case MethodType.Comparison: + return $"(({ResultName}.IsReversedStride == {Op1Name}.IsReversedStride) && ({ResultName}.IsReversedStride == {Op2Name}.IsReversedStride))"; + case MethodType.UnaryInPlace: + default: + throw new ArgumentException(); + } + } + + + public string GetElementOperation(string typeName, string access) + { + return GetElementOperation(typeName, access, access, access); + } + + public string GetElementOperation(string typeName, string resultAccess, string leftAccess, string rightAccess) + { + switch (MethodType) + { + case MethodType.Unary: + return $"{ResultName}{resultAccess} = ({typeName}){Operator}{Op1Name}{leftAccess}"; + case MethodType.UnaryInPlace: + return $"{ResultName}{resultAccess}{Operator}"; + case MethodType.BinaryScalar: + case MethodType.BinaryInt: + return $"{ResultName}{resultAccess} = ({typeName})({Op1Name}{leftAccess} {Operator} {Op2Name})"; + case MethodType.Binary: + return $"{ResultName}{resultAccess} = ({typeName})({Op1Name}{leftAccess} {Operator} {Op2Name}{rightAccess})"; + case MethodType.Comparison: + return $"{ResultName}{resultAccess} = {Op1Name}{leftAccess} {Operator} {Op2Name}{rightAccess}"; + default: + throw new ArgumentException(); + + } + } + + public string InitializeResult(string typeName) + { + switch (MethodType) + { + case 
MethodType.UnaryInPlace: + return $"{Op1Name}.Clone()"; + case MethodType.Unary: + case MethodType.BinaryScalar: + case MethodType.BinaryInt: + case MethodType.Binary: + return $"{Op1Name}.CloneEmpty()"; + case MethodType.Comparison: + return $"{Op1Name}.CloneEmpty()"; + case MethodType.Contraction: + return $"{Op1Name}.CloneEmpty(resultDimensions)"; + default: + throw new ArgumentException(); + } + } + + public bool IsNumeric { get; } + public bool IsBitwise { get; } + } + + + public MethodConfiguration[] methodConfiguration = new [] + { + new MethodConfiguration("Add", MethodType.Binary, "+", isNumeric:true), + new MethodConfiguration("Add", MethodType.BinaryScalar, "+", isNumeric:true), + new MethodConfiguration("UnaryPlus", MethodType.Unary, "+", isNumeric:true), + new MethodConfiguration("Subtract", MethodType.Binary, "-", isNumeric:true), + new MethodConfiguration("Subtract", MethodType.BinaryScalar, "-", isNumeric:true), + new MethodConfiguration("UnaryMinus", MethodType.Unary, "-", isNumeric:true), + new MethodConfiguration("Increment", MethodType.UnaryInPlace, "++", isNumeric:true), + new MethodConfiguration("Decrement", MethodType.UnaryInPlace, "--", isNumeric:true), + new MethodConfiguration("Multiply", MethodType.Binary, "*", isNumeric:true), // element-wise product, not matrix product + new MethodConfiguration("Multiply", MethodType.BinaryScalar, "*", isNumeric:true), + new MethodConfiguration("Divide", MethodType.Binary, "/", isNumeric:true), + new MethodConfiguration("Divide", MethodType.BinaryScalar, "/", isNumeric:true), + new MethodConfiguration("Modulo", MethodType.Binary, "%", isNumeric:true), + new MethodConfiguration("Modulo", MethodType.BinaryScalar, "%", isNumeric:true), + new MethodConfiguration("And", MethodType.Binary, "&", isBitwise: true), + new MethodConfiguration("And", MethodType.BinaryScalar, "&", isBitwise: true), + new MethodConfiguration("Or", MethodType.Binary, "|", isBitwise: true), + new MethodConfiguration("Or", MethodType.BinaryScalar, "|", isBitwise: true), + new MethodConfiguration("Xor", MethodType.Binary, "^", isBitwise: true), + new MethodConfiguration("Xor", MethodType.BinaryScalar, "^", isBitwise: true), + new MethodConfiguration("LeftShift", MethodType.BinaryInt, "<<", isBitwise: true), + new MethodConfiguration("RightShift", MethodType.BinaryInt, ">>", isBitwise: true), + + // Note all of these are element-wise operations not testing the operation on the entire Tensor + new MethodConfiguration("Equals", MethodType.Comparison, "=="), + new MethodConfiguration("NotEquals", MethodType.Comparison, "!="), + new MethodConfiguration("GreaterThanOrEqual", MethodType.Comparison, ">=", isNumeric:true), + new MethodConfiguration("LessThanOrEqual", MethodType.Comparison, "<=", isNumeric:true), + new MethodConfiguration("GreaterThan", MethodType.Comparison, ">", isNumeric:true), + new MethodConfiguration("LessThan", MethodType.Comparison, "<", isNumeric:true), + + new MethodConfiguration("Contract", MethodType.Contraction, isNumeric:true), + }.OrderBy(m => m.MethodName).ToArray(); +#> diff --git a/csharp/test/Microsoft.ML.OnnxRuntime.Tests/Tensors/TensorTests.cs b/csharp/test/Microsoft.ML.OnnxRuntime.Tests/Tensors/TensorTests.cs new file mode 100644 index 0000000000000..272484e1fb24f --- /dev/null +++ b/csharp/test/Microsoft.ML.OnnxRuntime.Tests/Tensors/TensorTests.cs @@ -0,0 +1,2243 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
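+
+// The TensorOperations methods exercised by these tests are generated by TensorOperations.tt
+// from the MethodConfiguration table in TensorTemplate.ttinclude: each entry expands into a
+// result-taking overload plus an allocating overload, both of which validate their arguments
+// and then dispatch to TensorArithmetic<T>.Instance. As a rough sketch (generic arguments
+// written out here; the exact text comes from GetGenericResultMethodSignature,
+// GetValidationMethod and InitializeResult), the binary Add entry expands to approximately:
+//
+//     internal static void Add<T>(Tensor<T> left, Tensor<T> right, Tensor<T> result)
+//     {
+//         ValidateBinaryArgs(left, right, result);
+//         TensorArithmetic<T>.Instance.Add(left, right, result);
+//     }
+//
+//     internal static Tensor<T> Add<T>(Tensor<T> left, Tensor<T> right)
+//     {
+//         ValidateBinaryArgs(left, right);
+//         var result = left.CloneEmpty();
+//         TensorArithmetic<T>.Instance.Add(left, right, result);
+//         return result;
+//     }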
+ +// This file is copied and adapted from the following git repository - +// https://github.com/dotnet/corefx +// Commit ID: bdd0814360d4c3a58860919f292a306242f27da1 +// Path: /src/System.Numerics.Tensors/tests/TensorTests.cs +// Original license statement below - + +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System.Collections; +using System.Collections.Generic; +using System.Linq; +using Xunit; +using System; + +namespace Microsoft.ML.OnnxRuntime.Tensors.Tests +{ + public class TensorTests : TensorTestsBase + { + [Theory()] + [MemberData(nameof(GetSingleTensorConstructors))] + public void ConstructTensorFromArrayRank1(TensorConstructor tensorConstructor) + { + var tensor = tensorConstructor.CreateFromArray(new[] { 0, 1, 2 }); + + Assert.Equal(tensorConstructor.IsReversedStride, tensor.IsReversedStride); + Assert.Equal(0, tensor[0]); + Assert.Equal(1, tensor[1]); + Assert.Equal(2, tensor[2]); + } + + [Theory()] + [MemberData(nameof(GetSingleTensorConstructors))] + public void ConstructTensorFromArrayRank2(TensorConstructor tensorConstructor) + { + var tensor = tensorConstructor.CreateFromArray(new[,] + { + {0, 1, 2}, + {3, 4, 5} + }); + + Assert.Equal(tensorConstructor.IsReversedStride, tensor.IsReversedStride); + Assert.Equal(0, tensor[0, 0]); + Assert.Equal(1, tensor[0, 1]); + Assert.Equal(2, tensor[0, 2]); + Assert.Equal(3, tensor[1, 0]); + Assert.Equal(4, tensor[1, 1]); + Assert.Equal(5, tensor[1, 2]); + } + + [Theory()] + [MemberData(nameof(GetSingleTensorConstructors))] + public void ConstructTensorFromArrayRank3(TensorConstructor tensorConstructor) + { + var tensor = tensorConstructor.CreateFromArray(new[, ,] + { + { + {0, 1, 2}, + {3, 4, 5} + }, + { + {6, 7 ,8 }, + {9, 10 ,11 }, + }, + { + {12, 13 ,14 }, + {15, 16 ,17 }, + }, + { + {18, 19 ,20 }, + {21, 22 ,23 }, + } + }); + + Assert.Equal(tensorConstructor.IsReversedStride, tensor.IsReversedStride); + + Assert.Equal(0, tensor[0, 0, 0]); + Assert.Equal(1, tensor[0, 0, 1]); + Assert.Equal(2, tensor[0, 0, 2]); + Assert.Equal(3, tensor[0, 1, 0]); + Assert.Equal(4, tensor[0, 1, 1]); + Assert.Equal(5, tensor[0, 1, 2]); + + Assert.Equal(6, tensor[1, 0, 0]); + Assert.Equal(7, tensor[1, 0, 1]); + Assert.Equal(8, tensor[1, 0, 2]); + Assert.Equal(9, tensor[1, 1, 0]); + Assert.Equal(10, tensor[1, 1, 1]); + Assert.Equal(11, tensor[1, 1, 2]); + + Assert.Equal(12, tensor[2, 0, 0]); + Assert.Equal(13, tensor[2, 0, 1]); + Assert.Equal(14, tensor[2, 0, 2]); + Assert.Equal(15, tensor[2, 1, 0]); + Assert.Equal(16, tensor[2, 1, 1]); + Assert.Equal(17, tensor[2, 1, 2]); + + Assert.Equal(18, tensor[3, 0, 0]); + Assert.Equal(19, tensor[3, 0, 1]); + Assert.Equal(20, tensor[3, 0, 2]); + Assert.Equal(21, tensor[3, 1, 0]); + Assert.Equal(22, tensor[3, 1, 1]); + Assert.Equal(23, tensor[3, 1, 2]); + } + + [Fact] + public void ConstructDenseTensorFromPointer() + { + using (var nativeMemory = NativeMemoryFromArray(Enumerable.Range(0, 24).ToArray())) + { + var dimensions = new[] { 4, 2, 3 }; + var tensor = new DenseTensor(nativeMemory.Memory, dimensions, false); + + Assert.Equal(0, tensor[0, 0, 0]); + Assert.Equal(1, tensor[0, 0, 1]); + Assert.Equal(2, tensor[0, 0, 2]); + Assert.Equal(3, tensor[0, 1, 0]); + Assert.Equal(4, tensor[0, 1, 1]); + Assert.Equal(5, tensor[0, 1, 2]); + + Assert.Equal(6, tensor[1, 0, 0]); + Assert.Equal(7, tensor[1, 0, 1]); + Assert.Equal(8, tensor[1, 0, 2]); + 
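+                // With dimensions { 4, 2, 3 } and the default row-major layout (reverseStride: false),
+                // the strides are { 6, 3, 1 }, so element [i, j, k] lives at offset i * 6 + j * 3 + k in
+                // the Enumerable.Range(0, 24) buffer; e.g. tensor[1, 0, 2] = 1 * 6 + 0 * 3 + 2 = 8.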
Assert.Equal(9, tensor[1, 1, 0]); + Assert.Equal(10, tensor[1, 1, 1]); + Assert.Equal(11, tensor[1, 1, 2]); + + Assert.Equal(12, tensor[2, 0, 0]); + Assert.Equal(13, tensor[2, 0, 1]); + Assert.Equal(14, tensor[2, 0, 2]); + Assert.Equal(15, tensor[2, 1, 0]); + Assert.Equal(16, tensor[2, 1, 1]); + Assert.Equal(17, tensor[2, 1, 2]); + + Assert.Equal(18, tensor[3, 0, 0]); + Assert.Equal(19, tensor[3, 0, 1]); + Assert.Equal(20, tensor[3, 0, 2]); + Assert.Equal(21, tensor[3, 1, 0]); + Assert.Equal(22, tensor[3, 1, 1]); + Assert.Equal(23, tensor[3, 1, 2]); + } + } + + + [Theory()] + [MemberData(nameof(GetSingleTensorConstructors))] + public void ConstructFromDimensions(TensorConstructor tensorConstructor) + { + var tensor = tensorConstructor.CreateFromDimensions(new[] { 2, 3, 4 }); + Assert.Equal(3, tensor.Rank); + Assert.Equal(3, tensor.Dimensions.Length); + Assert.Equal(2, tensor.Dimensions[0]); + Assert.Equal(3, tensor.Dimensions[1]); + Assert.Equal(4, tensor.Dimensions[2]); + Assert.Equal(24, tensor.Length); + Assert.Equal(tensorConstructor.IsReversedStride, tensor.IsReversedStride); + + //Assert.Throws("dimensions", () => tensorConstructor.CreateFromDimensions(dimensions: null)); + Assert.Throws("dimensions", () => tensorConstructor.CreateFromDimensions(dimensions: new int[0])); + + Assert.Throws("dimensions", () => tensorConstructor.CreateFromDimensions(dimensions: new[] { 1, 0 })); + Assert.Throws("dimensions", () => tensorConstructor.CreateFromDimensions(dimensions: new[] { 1, -1 })); + + // ensure dimensions are immutable + var dimensions = new[] { 1, 2, 3 }; + tensor = tensorConstructor.CreateFromDimensions(dimensions: dimensions); + dimensions[0] = dimensions[1] = dimensions[2] = 0; + Assert.Equal(1, tensor.Dimensions[0]); + Assert.Equal(2, tensor.Dimensions[1]); + Assert.Equal(3, tensor.Dimensions[2]); + } + + [Theory()] + [MemberData(nameof(GetSingleTensorConstructors))] + public void ConstructTensorFromArrayRank3WithLowerBounds(TensorConstructor tensorConstructor) + { + var dimensions = new[] { 2, 3, 4 }; + var lowerBounds = new[] { 0, 5, 200 }; + var arrayWithLowerBounds = Array.CreateInstance(typeof(int), dimensions, lowerBounds); + + int value = 0; + for (int x = lowerBounds[0]; x < lowerBounds[0] + dimensions[0]; x++) + { + for (int y = lowerBounds[1]; y < lowerBounds[1] + dimensions[1]; y++) + { + for (int z = lowerBounds[2]; z < lowerBounds[2] + dimensions[2]; z++) + { + arrayWithLowerBounds.SetValue(value++, x, y, z); + } + } + } + + var tensor = tensorConstructor.CreateFromArray(arrayWithLowerBounds); + + var expected = tensorConstructor.CreateFromArray(new[, ,] + { + { + { 0, 1, 2, 3 }, + { 4, 5, 6, 7 }, + { 8, 9, 10, 11 } + }, + { + { 12, 13, 14, 15 }, + { 16, 17, 18, 19 }, + { 20, 21, 22, 23 } + } + } + ); + Assert.True(StructuralComparisons.StructuralEqualityComparer.Equals(expected, tensor)); + Assert.Equal(tensorConstructor.IsReversedStride, tensor.IsReversedStride); + } + + [Theory()] + [MemberData(nameof(GetDualTensorConstructors))] + public void StructurallyEqualTensor(TensorConstructor leftConstructor, TensorConstructor rightConstructor) + { + var arr = new[, ,] + { + { + {0, 1, 2}, + {3, 4, 5} + }, + { + {6, 7 ,8 }, + {9, 10 ,11 }, + }, + { + {12, 13 ,14 }, + {15, 16 ,17 }, + }, + { + {18, 19 ,20 }, + {21, 22 ,23 }, + } + }; + var tensor = leftConstructor.CreateFromArray(arr); + var tensor2 = rightConstructor.CreateFromArray(arr); + + Assert.Equal(0, StructuralComparisons.StructuralComparer.Compare(tensor, tensor2)); + Assert.Equal(0, 
StructuralComparisons.StructuralComparer.Compare(tensor2, tensor)); + Assert.True(StructuralComparisons.StructuralEqualityComparer.Equals(tensor, tensor2)); + Assert.True(StructuralComparisons.StructuralEqualityComparer.Equals(tensor2, tensor)); + // Issue: should Tensors with different layout be structurally equal? + if (leftConstructor.IsReversedStride == leftConstructor.IsReversedStride) + { + Assert.Equal(StructuralComparisons.StructuralEqualityComparer.GetHashCode(tensor), StructuralComparisons.StructuralEqualityComparer.GetHashCode(tensor2)); + } + } + + [Theory()] + [MemberData(nameof(GetSingleTensorConstructors))] + public void StructurallyEqualArray(TensorConstructor tensorConstructor) + { + var arr = new[, ,] + { + { + {0, 1, 2}, + {3, 4, 5} + }, + { + {6, 7 ,8 }, + {9, 10 ,11 }, + }, + { + {12, 13 ,14 }, + {15, 16 ,17 }, + }, + { + {18, 19 ,20 }, + {21, 22 ,23 }, + } + }; + var tensor = tensorConstructor.CreateFromArray(arr); + + Assert.Equal(0, StructuralComparisons.StructuralComparer.Compare(tensor, arr)); + Assert.True(StructuralComparisons.StructuralEqualityComparer.Equals(tensor, arr)); + + } + + [Theory()] + [MemberData(nameof(GetSingleTensorConstructors))] + public void GetDiagonalSquare(TensorConstructor tensorConstructor) + { + var arr = new[,] + { + { 1, 2, 4 }, + { 8, 3, 9 }, + { 1, 7, 5 }, + }; + + var tensor = tensorConstructor.CreateFromArray(arr); + var diag = tensor.GetDiagonal(); + Assert.True(StructuralComparisons.StructuralEqualityComparer.Equals(diag, new[] { 1, 3, 5 })); + diag = tensor.GetDiagonal(1); + Assert.True(StructuralComparisons.StructuralEqualityComparer.Equals(diag, new[] { 2, 9 })); + diag = tensor.GetDiagonal(2); + Assert.True(StructuralComparisons.StructuralEqualityComparer.Equals(diag, new[] { 4 })); + Assert.Throws("offset", () => tensor.GetDiagonal(3)); + + diag = tensor.GetDiagonal(-1); + Assert.True(StructuralComparisons.StructuralEqualityComparer.Equals(diag, new[] { 8, 7 })); + diag = tensor.GetDiagonal(-2); + Assert.True(StructuralComparisons.StructuralEqualityComparer.Equals(diag, new[] { 1 })); + Assert.Throws("offset", () => tensor.GetDiagonal(-3)); + } + + [Theory()] + [MemberData(nameof(GetSingleTensorConstructors))] + public void GetDiagonalRectangle(TensorConstructor tensorConstructor) + { + var arr = new[,] + { + { 1, 2, 4, 3, 7 }, + { 8, 3, 9, 2, 6 }, + { 1, 7, 5, 2, 9 } + }; + + var tensor = tensorConstructor.CreateFromArray(arr); + var diag = tensor.GetDiagonal(); + Assert.True(StructuralComparisons.StructuralEqualityComparer.Equals(diag, new[] { 1, 3, 5 })); + diag = tensor.GetDiagonal(1); + Assert.True(StructuralComparisons.StructuralEqualityComparer.Equals(diag, new[] { 2, 9, 2 })); + diag = tensor.GetDiagonal(2); + Assert.True(StructuralComparisons.StructuralEqualityComparer.Equals(diag, new[] { 4, 2, 9 })); + diag = tensor.GetDiagonal(3); + Assert.True(StructuralComparisons.StructuralEqualityComparer.Equals(diag, new[] { 3, 6 })); + diag = tensor.GetDiagonal(4); + Assert.True(StructuralComparisons.StructuralEqualityComparer.Equals(diag, new[] { 7 })); + Assert.Throws("offset", () => tensor.GetDiagonal(5)); + + diag = tensor.GetDiagonal(-1); + Assert.True(StructuralComparisons.StructuralEqualityComparer.Equals(diag, new[] { 8, 7 })); + diag = tensor.GetDiagonal(-2); + Assert.True(StructuralComparisons.StructuralEqualityComparer.Equals(diag, new[] { 1 })); + Assert.Throws("offset", () => tensor.GetDiagonal(-3)); + Assert.Throws("offset", () => tensor.GetDiagonal(-4)); + Assert.Throws("offset", () => 
tensor.GetDiagonal(-5)); + } + + + [Theory()] + [MemberData(nameof(GetSingleTensorConstructors))] + public void GetDiagonalCube(TensorConstructor tensorConstructor) + { + var arr = new[, ,] + { + { + { 1, 2, 4 }, + { 8, 3, 9 }, + { 1, 7, 5 }, + }, + { + { 4, 5, 7 }, + { 1, 6, 2 }, + { 3, 0, 8 }, + }, + { + { 5, 6, 1 }, + { 2, 2, 3 }, + { 4, 9, 4 }, + }, + + }; + + var tensor = tensorConstructor.CreateFromArray(arr); + var diag = tensor.GetDiagonal(); + var expected = new[,] + { + { 1, 2, 4 }, + { 1, 6, 2 }, + { 4, 9, 4 } + }; + Assert.True(StructuralComparisons.StructuralEqualityComparer.Equals(diag, expected)); + Assert.Equal(tensorConstructor.IsReversedStride, diag.IsReversedStride); + } + + [Theory()] + [MemberData(nameof(GetSingleTensorConstructors))] + public void GetTriangleSquare(TensorConstructor tensorConstructor) + { + var arr = new[,] + { + { 1, 2, 4 }, + { 8, 3, 9 }, + { 1, 7, 5 }, + }; + + var tensor = tensorConstructor.CreateFromArray(arr); + var tri = tensor.GetTriangle(0); + Assert.Equal(tensorConstructor.IsReversedStride, tri.IsReversedStride); + + var expected = tensorConstructor.CreateFromArray(new[,] + { + { 1, 0, 0 }, + { 8, 3, 0 }, + { 1, 7, 5 }, + }); + Assert.True(StructuralComparisons.StructuralEqualityComparer.Equals(tri, expected)); + tri = tensor.GetTriangle(1); + expected = tensorConstructor.CreateFromArray(new[,] + { + { 1, 2, 0 }, + { 8, 3, 9 }, + { 1, 7, 5 }, + }); + Assert.True(StructuralComparisons.StructuralEqualityComparer.Equals(tri, expected)); + tri = tensor.GetTriangle(2); + expected = tensorConstructor.CreateFromArray(new[,] + { + { 1, 2, 4 }, + { 8, 3, 9 }, + { 1, 7, 5 }, + }); + Assert.True(StructuralComparisons.StructuralEqualityComparer.Equals(tri, expected)); + + tri = tensor.GetTriangle(3); + Assert.True(StructuralComparisons.StructuralEqualityComparer.Equals(tri, expected)); + + tri = tensor.GetTriangle(200); + Assert.True(StructuralComparisons.StructuralEqualityComparer.Equals(tri, expected)); + + tri = tensor.GetTriangle(-1); + expected = tensorConstructor.CreateFromArray(new[,] + { + { 0, 0, 0 }, + { 8, 0, 0 }, + { 1, 7, 0 }, + }); + Assert.True(StructuralComparisons.StructuralEqualityComparer.Equals(tri, expected)); + tri = tensor.GetTriangle(-2); + expected = tensorConstructor.CreateFromArray(new[,] + { + { 0, 0, 0 }, + { 0, 0, 0 }, + { 1, 0, 0 }, + }); + Assert.True(StructuralComparisons.StructuralEqualityComparer.Equals(tri, expected)); + + + expected = tensorConstructor.CreateFromArray(new[,] + { + { 0, 0, 0 }, + { 0, 0, 0 }, + { 0, 0, 0 }, + }); + tri = tensor.GetTriangle(-3); + Assert.True(StructuralComparisons.StructuralEqualityComparer.Equals(tri, expected)); + + // same as -3, should it be an exception? 
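+            // GetTriangle(offset) keeps the entries whose column index exceeds the row index by at most
+            // `offset` (column - row <= offset), so for this 3x3 input any offset <= -3 zeroes everything
+            // and any offset >= 2 returns the full tensor; out-of-range offsets are clamped, not rejected.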
+ tri = tensor.GetTriangle(-4); + Assert.True(StructuralComparisons.StructuralEqualityComparer.Equals(tri, expected)); + tri = tensor.GetTriangle(-300); + Assert.True(StructuralComparisons.StructuralEqualityComparer.Equals(tri, expected)); + } + + [Theory()] + [MemberData(nameof(GetSingleTensorConstructors))] + public void GetTriangleRectangle(TensorConstructor tensorConstructor) + { + var arr = new[,] + { + { 1, 2, 4, 3, 7 }, + { 8, 3, 9, 2, 6 }, + { 1, 7, 5, 2, 9 } + }; + + var tensor = tensorConstructor.CreateFromArray(arr); + var tri = tensor.GetTriangle(0); + var expected = tensorConstructor.CreateFromArray(new[,] + { + { 1, 0, 0, 0, 0 }, + { 8, 3, 0, 0, 0 }, + { 1, 7, 5, 0, 0 } + }); + Assert.True(StructuralComparisons.StructuralEqualityComparer.Equals(tri, expected)); + Assert.Equal(tensorConstructor.IsReversedStride, tri.IsReversedStride); + + tri = tensor.GetTriangle(1); + expected = tensorConstructor.CreateFromArray(new[,] + { + { 1, 2, 0, 0, 0 }, + { 8, 3, 9, 0, 0 }, + { 1, 7, 5, 2, 0 } + }); + Assert.True(StructuralComparisons.StructuralEqualityComparer.Equals(tri, expected)); + tri = tensor.GetTriangle(2); + expected = tensorConstructor.CreateFromArray(new[,] + { + { 1, 2, 4, 0, 0 }, + { 8, 3, 9, 2, 0 }, + { 1, 7, 5, 2, 9 } + }); + Assert.True(StructuralComparisons.StructuralEqualityComparer.Equals(tri, expected)); + tri = tensor.GetTriangle(3); + expected = tensorConstructor.CreateFromArray(new[,] + { + { 1, 2, 4, 3, 0 }, + { 8, 3, 9, 2, 6 }, + { 1, 7, 5, 2, 9 } + }); + Assert.True(StructuralComparisons.StructuralEqualityComparer.Equals(tri, expected)); + + tri = tensor.GetTriangle(4); + expected = tensorConstructor.CreateFromArray(new[,] + { + { 1, 2, 4, 3, 7 }, + { 8, 3, 9, 2, 6 }, + { 1, 7, 5, 2, 9 } + }); + Assert.True(StructuralComparisons.StructuralEqualityComparer.Equals(tri, expected)); + + // same as 4, should it be an exception? 
+ tri = tensor.GetTriangle(5); + Assert.True(StructuralComparisons.StructuralEqualityComparer.Equals(tri, expected)); + tri = tensor.GetTriangle(1000); + Assert.True(StructuralComparisons.StructuralEqualityComparer.Equals(tri, expected)); + + tri = tensor.GetTriangle(-1); + expected = tensorConstructor.CreateFromArray(new[,] + { + { 0, 0, 0, 0, 0 }, + { 8, 0, 0, 0, 0 }, + { 1, 7, 0, 0, 0 } + }); + Assert.True(StructuralComparisons.StructuralEqualityComparer.Equals(tri, expected)); + + expected = tensorConstructor.CreateFromArray(new[,] + { + { 0, 0, 0, 0, 0 }, + { 0, 0, 0, 0, 0 }, + { 1, 0, 0, 0, 0 } + }); + tri = tensor.GetTriangle(-2); + Assert.True(StructuralComparisons.StructuralEqualityComparer.Equals(tri, expected)); + + expected = tensorConstructor.CreateFromArray(new[,] + { + { 0, 0, 0, 0, 0 }, + { 0, 0, 0, 0, 0 }, + { 0, 0, 0, 0, 0 } + }); + tri = tensor.GetTriangle(-3); + Assert.True(StructuralComparisons.StructuralEqualityComparer.Equals(tri, expected)); + + tri = tensor.GetTriangle(-4); + Assert.True(StructuralComparisons.StructuralEqualityComparer.Equals(tri, expected)); + tri = tensor.GetTriangle(-5); + Assert.True(StructuralComparisons.StructuralEqualityComparer.Equals(tri, expected)); + tri = tensor.GetTriangle(-100); + Assert.True(StructuralComparisons.StructuralEqualityComparer.Equals(tri, expected)); + } + + [Theory()] + [MemberData(nameof(GetSingleTensorConstructors))] + public void GetTriangleCube(TensorConstructor tensorConstructor) + { + var arr = new[, ,] + { + { + { 1, 2, 4 }, + { 8, 3, 9 }, + { 1, 7, 5 }, + }, + { + { 4, 5, 7 }, + { 1, 6, 2 }, + { 3, 0, 8 }, + }, + { + { 5, 6, 1 }, + { 2, 2, 3 }, + { 4, 9, 4 }, + }, + + }; + + var tensor = tensorConstructor.CreateFromArray(arr); + var tri = tensor.GetTriangle(0); + var expected = tensorConstructor.CreateFromArray(new[, ,] + { + { + { 1, 2, 4 }, + { 0, 0, 0 }, + { 0, 0, 0 }, + }, + { + { 4, 5, 7 }, + { 1, 6, 2 }, + { 0, 0, 0 }, + }, + { + { 5, 6, 1 }, + { 2, 2, 3 }, + { 4, 9, 4 }, + }, + + }); + Assert.True(StructuralComparisons.StructuralEqualityComparer.Equals(tri, expected)); + Assert.Equal(tensorConstructor.IsReversedStride, tri.IsReversedStride); + } + + [Theory()] + [MemberData(nameof(GetSingleTensorConstructors))] + public void GetUpperTriangleSquare(TensorConstructor tensorConstructor) + { + var arr = new[,] + { + { 1, 2, 4 }, + { 8, 3, 9 }, + { 1, 7, 5 }, + }; + + var tensor = tensorConstructor.CreateFromArray(arr); + var tri = tensor.GetUpperTriangle(0); + + var expected = tensorConstructor.CreateFromArray(new[,] + { + { 1, 2, 4 }, + { 0, 3, 9 }, + { 0, 0, 5 }, + }); + Assert.True(StructuralComparisons.StructuralEqualityComparer.Equals(tri, expected)); + Assert.Equal(tensorConstructor.IsReversedStride, tri.IsReversedStride); + + tri = tensor.GetUpperTriangle(1); + expected = tensorConstructor.CreateFromArray(new[,] + { + { 0, 2, 4 }, + { 0, 0, 9 }, + { 0, 0, 0 }, + }); + Assert.True(StructuralComparisons.StructuralEqualityComparer.Equals(tri, expected)); + tri = tensor.GetUpperTriangle(2); + expected = tensorConstructor.CreateFromArray(new[,] + { + { 0, 0, 4 }, + { 0, 0, 0 }, + { 0, 0, 0 }, + }); + Assert.True(StructuralComparisons.StructuralEqualityComparer.Equals(tri, expected)); + + tri = tensor.GetUpperTriangle(3); + expected = tensorConstructor.CreateFromArray(new[,] + { + { 0, 0, 0 }, + { 0, 0, 0 }, + { 0, 0, 0 }, + }); + Assert.True(StructuralComparisons.StructuralEqualityComparer.Equals(tri, expected)); + + tri = tensor.GetUpperTriangle(4); + 
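+            // GetUpperTriangle(offset) is the complement: it keeps the entries with column - row >= offset,
+            // so on this 3x3 input any offset >= 3 yields the all-zero tensor expected here, with larger
+            // offsets clamped just like the lower-triangle case.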
Assert.True(StructuralComparisons.StructuralEqualityComparer.Equals(tri, expected)); + tri = tensor.GetUpperTriangle(42); + Assert.True(StructuralComparisons.StructuralEqualityComparer.Equals(tri, expected)); + + tri = tensor.GetUpperTriangle(-1); + expected = tensorConstructor.CreateFromArray(new[,] + { + { 1, 2, 4 }, + { 8, 3, 9 }, + { 0, 7, 5 }, + }); + Assert.True(StructuralComparisons.StructuralEqualityComparer.Equals(tri, expected)); + tri = tensor.GetUpperTriangle(-2); + expected = tensorConstructor.CreateFromArray(new[,] + { + { 1, 2, 4 }, + { 8, 3, 9 }, + { 1, 7, 5 }, + }); + Assert.True(StructuralComparisons.StructuralEqualityComparer.Equals(tri, expected)); + + tri = tensor.GetUpperTriangle(-3); + Assert.True(StructuralComparisons.StructuralEqualityComparer.Equals(tri, expected)); + tri = tensor.GetUpperTriangle(-300); + Assert.True(StructuralComparisons.StructuralEqualityComparer.Equals(tri, expected)); + } + + [Theory()] + [MemberData(nameof(GetSingleTensorConstructors))] + public void GetUpperTriangleRectangle(TensorConstructor tensorConstructor) + { + var arr = new[,] + { + { 1, 2, 4, 3, 7 }, + { 8, 3, 9, 2, 6 }, + { 1, 7, 5, 2, 9 } + }; + + var tensor = tensorConstructor.CreateFromArray(arr); + var tri = tensor.GetUpperTriangle(0); + var expected = tensorConstructor.CreateFromArray(new[,] + { + { 1, 2, 4, 3, 7 }, + { 0, 3, 9, 2, 6 }, + { 0, 0, 5, 2, 9 } + }); + Assert.True(StructuralComparisons.StructuralEqualityComparer.Equals(tri, expected)); + Assert.Equal(tensorConstructor.IsReversedStride, tri.IsReversedStride); + tri = tensor.GetUpperTriangle(1); + expected = tensorConstructor.CreateFromArray(new[,] + { + { 0, 2, 4, 3, 7 }, + { 0, 0, 9, 2, 6 }, + { 0, 0, 0, 2, 9 } + }); + Assert.True(StructuralComparisons.StructuralEqualityComparer.Equals(tri, expected)); + tri = tensor.GetUpperTriangle(2); + expected = tensorConstructor.CreateFromArray(new[,] + { + { 0, 0, 4, 3, 7 }, + { 0, 0, 0, 2, 6 }, + { 0, 0, 0, 0, 9 } + }); + Assert.True(StructuralComparisons.StructuralEqualityComparer.Equals(tri, expected)); + tri = tensor.GetUpperTriangle(3); + expected = tensorConstructor.CreateFromArray(new[,] + { + { 0, 0, 0, 3, 7 }, + { 0, 0, 0, 0, 6 }, + { 0, 0, 0, 0, 0 } + }); + Assert.True(StructuralComparisons.StructuralEqualityComparer.Equals(tri, expected)); + + tri = tensor.GetUpperTriangle(4); + expected = tensorConstructor.CreateFromArray(new[,] + { + { 0, 0, 0, 0, 7 }, + { 0, 0, 0, 0, 0 }, + { 0, 0, 0, 0, 0 } + }); + Assert.True(StructuralComparisons.StructuralEqualityComparer.Equals(tri, expected)); + + expected = tensorConstructor.CreateFromArray(new[,] + { + { 0, 0, 0, 0, 0 }, + { 0, 0, 0, 0, 0 }, + { 0, 0, 0, 0, 0 } + }); + tri = tensor.GetUpperTriangle(5); + Assert.True(StructuralComparisons.StructuralEqualityComparer.Equals(tri, expected)); + tri = tensor.GetUpperTriangle(6); + Assert.True(StructuralComparisons.StructuralEqualityComparer.Equals(tri, expected)); + tri = tensor.GetUpperTriangle(1000); + Assert.True(StructuralComparisons.StructuralEqualityComparer.Equals(tri, expected)); + + tri = tensor.GetUpperTriangle(-1); + expected = tensorConstructor.CreateFromArray(new[,] + { + { 1, 2, 4, 3, 7 }, + { 8, 3, 9, 2, 6 }, + { 0, 7, 5, 2, 9 } + }); + Assert.True(StructuralComparisons.StructuralEqualityComparer.Equals(tri, expected)); + + expected = tensorConstructor.CreateFromArray(new[,] + { + { 1, 2, 4, 3, 7 }, + { 8, 3, 9, 2, 6 }, + { 1, 7, 5, 2, 9 } + }); + tri = tensor.GetUpperTriangle(-2); + Assert.True(StructuralComparisons.StructuralEqualityComparer.Equals(tri, 
expected)); + + tri = tensor.GetUpperTriangle(-3); + Assert.True(StructuralComparisons.StructuralEqualityComparer.Equals(tri, expected)); + tri = tensor.GetUpperTriangle(-4); + Assert.True(StructuralComparisons.StructuralEqualityComparer.Equals(tri, expected)); + tri = tensor.GetUpperTriangle(-100); + Assert.True(StructuralComparisons.StructuralEqualityComparer.Equals(tri, expected)); + } + + [Theory()] + [MemberData(nameof(GetSingleTensorConstructors))] + public void GetUpperTriangleCube(TensorConstructor tensorConstructor) + { + var arr = new[, ,] + { + { + { 1, 2, 4 }, + { 8, 3, 9 }, + { 1, 7, 5 }, + }, + { + { 4, 5, 7 }, + { 1, 6, 2 }, + { 3, 0, 8 }, + }, + { + { 5, 6, 1 }, + { 2, 2, 3 }, + { 4, 9, 4 }, + }, + + }; + + var tensor = tensorConstructor.CreateFromArray(arr); + var tri = tensor.GetUpperTriangle(0); + var expected = tensorConstructor.CreateFromArray(new[, ,] + { + { + { 1, 2, 4 }, + { 8, 3, 9 }, + { 1, 7, 5 }, + }, + { + { 0, 0, 0 }, + { 1, 6, 2 }, + { 3, 0, 8 }, + }, + { + { 0, 0, 0 }, + { 0, 0, 0 }, + { 4, 9, 4 }, + }, + + }); + Assert.True(StructuralComparisons.StructuralEqualityComparer.Equals(tri, expected)); + Assert.Equal(tensorConstructor.IsReversedStride, tri.IsReversedStride); + } + + [Theory()] + [MemberData(nameof(GetSingleTensorConstructors))] + public void Reshape(TensorConstructor tensorConstructor) + { + var arr = new[,] + { + { 1, 2, 3 }, + { 4, 5, 6 } + }; + + var tensor = tensorConstructor.CreateFromArray(arr); + var actual = tensor.Reshape(new[] { 3, 2 }); + + var expected = tensorConstructor.IsReversedStride ? + new[,] + { + { 1, 5 }, + { 4, 3 }, + { 2, 6 } + } : + new[,] + { + { 1, 2 }, + { 3, 4 }, + { 5, 6 } + }; + Assert.True(StructuralComparisons.StructuralEqualityComparer.Equals(actual, expected)); + Assert.Equal(tensorConstructor.IsReversedStride, actual.IsReversedStride); + } + + [Fact] + public void Identity() + { + var actual = Tensor.CreateIdentity(3); + + var expected = new[,] + { + {1.0, 0, 0 }, + {0, 1.0, 0 }, + {0, 0, 1.0 } + }; + + Assert.True(StructuralComparisons.StructuralEqualityComparer.Equals(actual, expected)); + } + + [Theory] + [MemberData(nameof(GetSingleTensorConstructors))] + public void CreateWithDiagonal(TensorConstructor tensorConstructor) + { + var diagonal = tensorConstructor.CreateFromArray(new[] { 1, 2, 3, 4, 5 }); + var actual = Tensor.CreateFromDiagonal(diagonal); + + var expected = new[,] + { + {1, 0, 0, 0, 0 }, + {0, 2, 0, 0, 0 }, + {0, 0, 3, 0, 0 }, + {0, 0, 0, 4, 0 }, + {0, 0, 0, 0, 5 } + }; + + Assert.True(StructuralComparisons.StructuralEqualityComparer.Equals(actual, expected)); + } + + [Theory] + [MemberData(nameof(GetSingleTensorConstructors))] + public void CreateWithDiagonal3D(TensorConstructor tensorConstructor) + { + var diagonal = tensorConstructor.CreateFromArray(new[,] + { + { 1, 2, 3, 4, 5 }, + { 1, 2, 3, 4, 5 }, + { 1, 2, 3, 4, 5 } + }); + var actual = Tensor.CreateFromDiagonal(diagonal); + var expected = new[, ,] + { + { + {1, 2, 3, 4, 5 }, + {0, 0, 0, 0, 0 }, + {0, 0, 0, 0, 0 } + }, + { + {0, 0, 0, 0, 0 }, + {1, 2, 3, 4, 5 }, + {0, 0, 0, 0, 0 } + }, + { + {0, 0, 0, 0, 0 }, + {0, 0, 0, 0, 0 }, + {1, 2, 3, 4, 5 } + } + }; + + Assert.True(StructuralComparisons.StructuralEqualityComparer.Equals(actual, expected)); + } + + [Theory] + [MemberData(nameof(GetSingleTensorConstructors))] + public void CreateWithDiagonalAndOffset(TensorConstructor tensorConstructor) + { + var diagonal = tensorConstructor.CreateFromArray(new[] { 1, 2, 3, 4 }); + var actual = Tensor.CreateFromDiagonal(diagonal, 1); + + var 
expected = new[,] + { + {0, 1, 0, 0, 0 }, + {0, 0, 2, 0, 0 }, + {0, 0, 0, 3, 0 }, + {0, 0, 0, 0, 4 }, + {0, 0, 0, 0, 0 } + }; + + Assert.True(StructuralComparisons.StructuralEqualityComparer.Equals(actual, expected)); + + diagonal = tensorConstructor.CreateFromArray(new[] { 1, 2, 3, 4 }); + actual = Tensor.CreateFromDiagonal(diagonal, -1); + + expected = new[,] + { + {0, 0, 0, 0, 0 }, + {1, 0, 0, 0, 0 }, + {0, 2, 0, 0, 0 }, + {0, 0, 3, 0, 0 }, + {0, 0, 0, 4, 0 } + }; + + Assert.True(StructuralComparisons.StructuralEqualityComparer.Equals(actual, expected)); + + diagonal = tensorConstructor.CreateFromArray(new[] { 1 }); + actual = Tensor.CreateFromDiagonal(diagonal, -4); + expected = new[,] + { + {0, 0, 0, 0, 0 }, + {0, 0, 0, 0, 0 }, + {0, 0, 0, 0, 0 }, + {0, 0, 0, 0, 0 }, + {1, 0, 0, 0, 0 } + }; + Assert.True(StructuralComparisons.StructuralEqualityComparer.Equals(actual, expected)); + + diagonal = tensorConstructor.CreateFromArray(new[] { 1 }); + actual = Tensor.CreateFromDiagonal(diagonal, 4); + expected = new[,] + { + {0, 0, 0, 0, 1 }, + {0, 0, 0, 0, 0 }, + {0, 0, 0, 0, 0 }, + {0, 0, 0, 0, 0 }, + {0, 0, 0, 0, 0 } + }; + Assert.True(StructuralComparisons.StructuralEqualityComparer.Equals(actual, expected)); + } + + [Theory] + [MemberData(nameof(GetSingleTensorConstructors))] + public void CreateWithDiagonalAndOffset3D(TensorConstructor tensorConstructor) + { + var diagonal = tensorConstructor.CreateFromArray(new[,] + { + { 1, 2, 3 }, + { 1, 2, 3 }, + { 1, 2, 3 } + }); + var actual = Tensor.CreateFromDiagonal(diagonal, 1); + + var expected = new[, ,] + { + { + { 0, 0, 0 }, + { 1, 2, 3 }, + { 0, 0, 0 }, + { 0, 0, 0 } + }, + { + { 0, 0, 0 }, + { 0, 0, 0 }, + { 1, 2, 3 }, + { 0, 0, 0 } + }, + { + { 0, 0, 0 }, + { 0, 0, 0 }, + { 0, 0, 0 }, + { 1, 2, 3 } + }, + { + { 0, 0, 0 }, + { 0, 0, 0 }, + { 0, 0, 0 }, + { 0, 0, 0 } + } + }; + + Assert.True(StructuralComparisons.StructuralEqualityComparer.Equals(actual, expected)); + + diagonal = tensorConstructor.CreateFromArray(new[,] + { + { 1, 2, 3 }, + { 1, 2, 3 }, + { 1, 2, 3 } + }); + actual = Tensor.CreateFromDiagonal(diagonal, -1); + + expected = new[, ,] + { + { + { 0, 0, 0 }, + { 0, 0, 0 }, + { 0, 0, 0 }, + { 0, 0, 0 } + }, + { + { 1, 2, 3 }, + { 0, 0, 0 }, + { 0, 0, 0 }, + { 0, 0, 0 } + }, + { + { 0, 0, 0 }, + { 1, 2, 3 }, + { 0, 0, 0 }, + { 0, 0, 0 } + }, + { + { 0, 0, 0 }, + { 0, 0, 0 }, + { 1, 2, 3 }, + { 0, 0, 0 } + } + }; + + Assert.True(StructuralComparisons.StructuralEqualityComparer.Equals(actual, expected)); + + diagonal = tensorConstructor.CreateFromArray(new[,] + { + { 1, 2, 3 } + }); + actual = Tensor.CreateFromDiagonal(diagonal, 3); + + expected = new[, ,] + { + { + { 0, 0, 0 }, + { 0, 0, 0 }, + { 0, 0, 0 }, + { 1, 2, 3 }, + }, + { + { 0, 0, 0 }, + { 0, 0, 0 }, + { 0, 0, 0 }, + { 0, 0, 0 } + }, + { + { 0, 0, 0 }, + { 0, 0, 0 }, + { 0, 0, 0 }, + { 0, 0, 0 } + }, + { + { 0, 0, 0 }, + { 0, 0, 0 }, + { 0, 0, 0 }, + { 0, 0, 0 } + } + }; + + Assert.True(StructuralComparisons.StructuralEqualityComparer.Equals(actual, expected)); + + diagonal = tensorConstructor.CreateFromArray(new[,] + { + { 1, 2, 3 } + }); + actual = Tensor.CreateFromDiagonal(diagonal, -3); + + expected = new[, ,] + { + { + { 0, 0, 0 }, + { 0, 0, 0 }, + { 0, 0, 0 }, + { 0, 0, 0 }, + }, + { + { 0, 0, 0 }, + { 0, 0, 0 }, + { 0, 0, 0 }, + { 0, 0, 0 } + }, + { + { 0, 0, 0 }, + { 0, 0, 0 }, + { 0, 0, 0 }, + { 0, 0, 0 } + }, + { + { 1, 2, 3 }, + { 0, 0, 0 }, + { 0, 0, 0 }, + { 0, 0, 0 } + } + }; + + Assert.True(StructuralComparisons.StructuralEqualityComparer.Equals(actual, 
expected)); + } + + [Theory()] + [MemberData(nameof(GetDualTensorConstructors))] + public void Add(TensorConstructor leftConstructor, TensorConstructor rightConstructor) + { + var left = leftConstructor.CreateFromArray( + new[,] + { + {0, 1, 2}, + {3, 4, 5} + }); + var right = rightConstructor.CreateFromArray( + new[,] + { + { 6, 7 ,8 }, + { 9, 10 ,11 }, + }); + + var expected = leftConstructor.CreateFromArray( + new[,] + { + { 6, 8, 10 }, + { 12, 14, 16 }, + }); + + var actual = TensorOperations.Add(left, right); + Assert.True(StructuralComparisons.StructuralEqualityComparer.Equals(actual, expected)); + Assert.Equal(leftConstructor.IsReversedStride, actual.IsReversedStride); + + } + + [Theory()] + [MemberData(nameof(GetSingleTensorConstructors))] + public void AddScalar(TensorConstructor tensorConstructor) + { + var tensor = tensorConstructor.CreateFromArray( + new[,] + { + {0, 1, 2}, + {3, 4, 5} + }); + + var expected = tensorConstructor.CreateFromArray( + new[,] + { + { 1, 2, 3 }, + { 4, 5, 6 }, + }); + + var actual = TensorOperations.Add(tensor, 1); + Assert.True(StructuralComparisons.StructuralEqualityComparer.Equals(actual, expected)); + Assert.Equal(tensorConstructor.IsReversedStride, actual.IsReversedStride); + + } + + [Theory()] + [MemberData(nameof(GetSingleTensorConstructors))] + public void UnaryPlus(TensorConstructor tensorConstructor) + { + var tensor = tensorConstructor.CreateFromArray( + new[,] + { + {0, 1, 2}, + {3, 4, 5} + }); + + var expected = tensor; + + var actual = TensorOperations.UnaryPlus(tensor); + Assert.True(StructuralComparisons.StructuralEqualityComparer.Equals(actual, expected)); + Assert.False(ReferenceEquals(actual, expected)); + Assert.Equal(tensorConstructor.IsReversedStride, actual.IsReversedStride); + } + + + [Theory()] + [MemberData(nameof(GetDualTensorConstructors))] + public void Subtract(TensorConstructor leftConstructor, TensorConstructor rightConstructor) + { + var left = leftConstructor.CreateFromArray( + new[,] + { + {0, 1, 2}, + {3, 4, 5} + }); + var right = rightConstructor.CreateFromArray( + new[,] + { + { 6, 7 ,8 }, + { 9, 10 ,11 }, + }); + + var expected = leftConstructor.CreateFromArray( + new[,] + { + { -6, -6, -6 }, + { -6, -6, -6}, + }); + + var actual = TensorOperations.Subtract(left, right); + Assert.True(StructuralComparisons.StructuralEqualityComparer.Equals(actual, expected)); + Assert.Equal(leftConstructor.IsReversedStride, actual.IsReversedStride); + } + + [Theory()] + [MemberData(nameof(GetSingleTensorConstructors))] + public void SubtractScalar(TensorConstructor tensorConstructor) + { + var tensor = tensorConstructor.CreateFromArray( + new[,] + { + {0, 1, 2}, + {3, 4, 5} + }); + var expected = tensorConstructor.CreateFromArray( + new[,] + { + { -1, 0, 1 }, + { 2, 3, 4 }, + }); + + var actual = TensorOperations.Subtract(tensor, 1); + Assert.True(StructuralComparisons.StructuralEqualityComparer.Equals(actual, expected)); + Assert.Equal(tensorConstructor.IsReversedStride, actual.IsReversedStride); + } + + [Theory()] + [MemberData(nameof(GetSingleTensorConstructors))] + public void UnaryMinus(TensorConstructor tensorConstructor) + { + var tensor = tensorConstructor.CreateFromArray( + new[,] + { + {0, 1, 2}, + {3, 4, 5} + }); + + var expected = tensorConstructor.CreateFromArray( + new[,] + { + {0, -1, -2}, + {-3, -4, -5} + }); + + var actual = TensorOperations.UnaryMinus(tensor); + Assert.True(StructuralComparisons.StructuralEqualityComparer.Equals(actual, expected)); + Assert.False(ReferenceEquals(actual, expected)); + 
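+            // TensorOperations.UnaryMinus allocates its result via CloneEmpty rather than negating in
+            // place, and CloneEmpty preserves the operand's stride layout, which the assertion below checks.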
Assert.Equal(tensorConstructor.IsReversedStride, actual.IsReversedStride); + } + + [Theory()] + [MemberData(nameof(GetSingleTensorConstructors))] + public void PrefixIncrement(TensorConstructor tensorConstructor) + { + Tensor tensor = tensorConstructor.CreateFromArray( + new[,] + { + {0, 1, 2}, + {3, 4, 5} + }); + + var expectedResult = tensorConstructor.CreateFromArray( + new[,] + { + {1, 2, 3}, + {4, 5, 6} + }); + + var expectedTensor = expectedResult; + + tensor = TensorOperations.Increment(tensor); + var actual = tensor; + Assert.True(StructuralComparisons.StructuralEqualityComparer.Equals(actual, expectedResult)); + Assert.True(StructuralComparisons.StructuralEqualityComparer.Equals(tensor, expectedTensor)); + Assert.True(ReferenceEquals(tensor, actual)); + Assert.Equal(tensorConstructor.IsReversedStride, actual.IsReversedStride); + } + + + [Theory()] + [MemberData(nameof(GetSingleTensorConstructors))] + public void PostfixIncrement(TensorConstructor tensorConstructor) + { + Tensor tensor = tensorConstructor.CreateFromArray( + new[,] + { + {0, 1, 2}, + {3, 4, 5} + }); + + // returns original value + var expectedResult = tensorConstructor.CreateFromArray( + new[,] + { + {0, 1, 2}, + {3, 4, 5} + }); + + // increments operand + var expectedTensor = tensorConstructor.CreateFromArray( + new[,] + { + {1, 2, 3}, + {4, 5, 6} + }); ; + + var actual = tensor; + tensor = TensorOperations.Increment(tensor); + Assert.True(StructuralComparisons.StructuralEqualityComparer.Equals(actual, expectedResult)); + Assert.True(StructuralComparisons.StructuralEqualityComparer.Equals(tensor, expectedTensor)); + Assert.False(ReferenceEquals(tensor, actual)); + Assert.Equal(tensorConstructor.IsReversedStride, actual.IsReversedStride); + } + + + [Theory()] + [MemberData(nameof(GetSingleTensorConstructors))] + public void PrefixDecrement(TensorConstructor tensorConstructor) + { + Tensor tensor = tensorConstructor.CreateFromArray( + new[,] + { + {0, 1, 2}, + {3, 4, 5} + }); + + var expectedResult = tensorConstructor.CreateFromArray( + new[,] + { + {-1, 0, 1}, + {2, 3, 4} + }); + + var expectedTensor = expectedResult; + + tensor = TensorOperations.Decrement(tensor); + var actual = tensor; + Assert.True(StructuralComparisons.StructuralEqualityComparer.Equals(actual, expectedResult)); + Assert.True(StructuralComparisons.StructuralEqualityComparer.Equals(tensor, expectedTensor)); + Assert.True(ReferenceEquals(tensor, actual)); + Assert.Equal(tensorConstructor.IsReversedStride, actual.IsReversedStride); + } + + [Theory()] + [MemberData(nameof(GetSingleTensorConstructors))] + public void PostfixDecrement(TensorConstructor tensorConstructor) + { + Tensor tensor = tensorConstructor.CreateFromArray( + new[,] + { + {0, 1, 2}, + {3, 4, 5} + }); + + // returns original value + var expectedResult = tensorConstructor.CreateFromArray( + new[,] + { + {0, 1, 2}, + {3, 4, 5} + }); + + // decrements operand + var expectedTensor = tensorConstructor.CreateFromArray( + new[,] + { + {-1, 0, 1}, + {2, 3, 4} + }); ; + + var actual = tensor; + tensor = TensorOperations.Decrement(tensor); + Assert.True(StructuralComparisons.StructuralEqualityComparer.Equals(actual, expectedResult)); + Assert.True(StructuralComparisons.StructuralEqualityComparer.Equals(tensor, expectedTensor)); + Assert.False(ReferenceEquals(tensor, actual)); + Assert.Equal(tensorConstructor.IsReversedStride, actual.IsReversedStride); + } + + [Theory()] + [MemberData(nameof(GetDualTensorConstructors))] + public void Multiply(TensorConstructor leftConstructor, 
TensorConstructor rightConstructor) + { + var left = leftConstructor.CreateFromArray( + new[,] + { + {0, 1, 2}, + {3, 4, 5} + }); + var right = rightConstructor.CreateFromArray( + new[,] + { + {0, 1, 2}, + {3, 4, 5} + }); + + var expected = leftConstructor.CreateFromArray( + new[,] + { + {0, 1, 4}, + {9, 16, 25} + }); + + var actual = TensorOperations.Multiply(left, right); + Assert.True(StructuralComparisons.StructuralEqualityComparer.Equals(actual, expected)); + Assert.Equal(leftConstructor.IsReversedStride, actual.IsReversedStride); + } + + [Theory()] + [MemberData(nameof(GetSingleTensorConstructors))] + public void MultiplyScalar(TensorConstructor tensorConstructor) + { + var tensor = tensorConstructor.CreateFromArray( + new[,] + { + {0, 1, 2}, + {3, 4, 5} + }); + + var expected = tensorConstructor.CreateFromArray( + new[,] + { + {0, 2, 4}, + {6, 8, 10} + }); + + var actual = TensorOperations.Multiply(tensor, 2); + Assert.True(StructuralComparisons.StructuralEqualityComparer.Equals(actual, expected)); + Assert.Equal(tensorConstructor.IsReversedStride, actual.IsReversedStride); + } + + [Theory()] + [MemberData(nameof(GetDualTensorConstructors))] + public void Divide(TensorConstructor dividendConstructor, TensorConstructor divisorConstructor) + { + var dividend = dividendConstructor.CreateFromArray( + new[,] + { + {0, 1, 4}, + {9, 16, 25} + }); + + var divisor = divisorConstructor.CreateFromArray( + new[,] + { + {1, 1, 2}, + {3, 4, 5} + }); + + var expected = divisorConstructor.CreateFromArray( + new[,] + { + {0, 1, 2}, + {3, 4, 5} + }); + + var actual = TensorOperations.Divide(dividend, divisor); + Assert.True(StructuralComparisons.StructuralEqualityComparer.Equals(actual, expected)); + Assert.Equal(dividendConstructor.IsReversedStride, actual.IsReversedStride); + } + + [Theory()] + [MemberData(nameof(GetSingleTensorConstructors))] + public void DivideScalar(TensorConstructor tensorConstructor) + { + var tensor = tensorConstructor.CreateFromArray( + new[,] + { + {0, 2, 4}, + {6, 8, 10} + }); + + var expected = tensorConstructor.CreateFromArray( + new[,] + { + {0, 1, 2}, + {3, 4, 5} + }); + + var actual = TensorOperations.Divide(tensor, 2); + Assert.True(StructuralComparisons.StructuralEqualityComparer.Equals(actual, expected)); + Assert.Equal(tensorConstructor.IsReversedStride, actual.IsReversedStride); + } + + [Theory()] + [MemberData(nameof(GetDualTensorConstructors))] + public void Modulo(TensorConstructor dividendConstructor, TensorConstructor divisorConstructor) + { + var dividend = dividendConstructor.CreateFromArray( + new[,] + { + {0, 3, 8}, + {11, 14, 17} + }); + + var divisor = divisorConstructor.CreateFromArray( + new[,] + { + {1, 2, 3}, + {4, 5, 6} + }); + + var expected = dividendConstructor.CreateFromArray( + new[,] + { + {0, 1, 2}, + {3, 4, 5} + }); + + var actual = TensorOperations.Modulo(dividend, divisor); + Assert.True(StructuralComparisons.StructuralEqualityComparer.Equals(actual, expected)); + Assert.Equal(dividendConstructor.IsReversedStride, actual.IsReversedStride); + } + + [Theory()] + [MemberData(nameof(GetSingleTensorConstructors))] + public void ModuloScalar(TensorConstructor tensorConstructor) + { + var tensor = tensorConstructor.CreateFromArray( + new[,] + { + {0, 3, 4}, + {7, 8, 9} + }); + + var expected = tensorConstructor.CreateFromArray( + new[,] + { + {0, 1, 0}, + {1, 0, 1} + }); + + var actual = TensorOperations.Modulo(tensor, 2); + Assert.True(StructuralComparisons.StructuralEqualityComparer.Equals(actual, expected)); + 
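// the remainder is computed element-wise; the result keeps the operand's stride ordering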
Assert.Equal(tensorConstructor.IsReversedStride, actual.IsReversedStride); + } + + [Theory()] + [MemberData(nameof(GetDualTensorConstructors))] + public void And(TensorConstructor leftConstructor, TensorConstructor rightConstructor) + { + var left = leftConstructor.CreateFromArray( + new[,] + { + {0, 1, 3}, + {7, 15, 31} + }); + + var right = rightConstructor.CreateFromArray( + new[,] + { + {1, 1, 3}, + {2, 4, 8} + }); + + var expected = leftConstructor.CreateFromArray( + new[,] + { + {0, 1, 3}, + {2, 4, 8} + }); + + var actual = TensorOperations.And(left, right); + Assert.True(StructuralComparisons.StructuralEqualityComparer.Equals(actual, expected)); + Assert.Equal(leftConstructor.IsReversedStride, actual.IsReversedStride); + } + + [Theory()] + [MemberData(nameof(GetSingleTensorConstructors))] + public void AndScalar(TensorConstructor tensorConstructor) + { + var left = tensorConstructor.CreateFromArray( + new[,] + { + {0, 1, 3}, + {5, 15, 31} + }); + + var expected = tensorConstructor.CreateFromArray( + new[,] + { + {0, 0, 0}, + {4, 4, 20} + }); + + var actual = TensorOperations.And(left, 20); + Assert.True(StructuralComparisons.StructuralEqualityComparer.Equals(actual, expected)); + Assert.Equal(tensorConstructor.IsReversedStride, actual.IsReversedStride); + } + + [Theory()] + [MemberData(nameof(GetDualTensorConstructors))] + public void Or(TensorConstructor leftConstructor, TensorConstructor rightConstructor) + { + var left = leftConstructor.CreateFromArray( + new[,] + { + {0, 1, 3}, + {7, 14, 31} + }); + + var right = rightConstructor.CreateFromArray( + new[,] + { + {1, 2, 4}, + {2, 4, 8} + }); + + var expected = leftConstructor.CreateFromArray( + new[,] + { + {1, 3, 7}, + {7, 14, 31} + }); + + var actual = TensorOperations.Or(left, right); + Assert.True(StructuralComparisons.StructuralEqualityComparer.Equals(actual, expected)); + Assert.Equal(leftConstructor.IsReversedStride, actual.IsReversedStride); + } + + [Theory()] + [MemberData(nameof(GetSingleTensorConstructors))] + public void OrScalar(TensorConstructor tensorConstructor) + { + var left = tensorConstructor.CreateFromArray( + new[,] + { + {0, 1, 2}, + {3, 4, 5} + }); + + var expected = tensorConstructor.CreateFromArray( + new[,] + { + {1, 1, 3}, + {3, 5, 5} + }); + + var actual = TensorOperations.Or(left, 1); + Assert.True(StructuralComparisons.StructuralEqualityComparer.Equals(actual, expected)); + Assert.Equal(tensorConstructor.IsReversedStride, actual.IsReversedStride); + } + + [Theory()] + [MemberData(nameof(GetDualTensorConstructors))] + public void Xor(TensorConstructor leftConstructor, TensorConstructor rightConstructor) + { + var left = leftConstructor.CreateFromArray( + new[,] + { + {0, 1, 3}, + {7, 14, 31} + }); + + var right = rightConstructor.CreateFromArray( + new[,] + { + {1, 2, 4}, + {2, 4, 8} + }); + + var expected = leftConstructor.CreateFromArray( + new[,] + { + {1, 3, 7}, + {5, 10, 23} + }); + + var actual = TensorOperations.Xor(left, right); + Assert.True(StructuralComparisons.StructuralEqualityComparer.Equals(actual, expected)); + Assert.Equal(leftConstructor.IsReversedStride, actual.IsReversedStride); + } + + [Theory()] + [MemberData(nameof(GetSingleTensorConstructors))] + public void XorScalar(TensorConstructor tensorConstructor) + { + var left = tensorConstructor.CreateFromArray( + new[,] + { + {0, 1, 2}, + {3, 4, 5} + }); + + var expected = tensorConstructor.CreateFromArray( + new[,] + { + {1, 0, 3}, + {2, 5, 4} + }); + + var actual = TensorOperations.Xor(left, 1); + 
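// XOR with 1 flips the lowest bit of each element: {0,1,2,3,4,5} ^ 1 == {1,0,3,2,5,4}, matching 'expected' above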
Assert.True(StructuralComparisons.StructuralEqualityComparer.Equals(actual, expected)); + Assert.Equal(tensorConstructor.IsReversedStride, actual.IsReversedStride); + } + + [Theory()] + [MemberData(nameof(GetSingleTensorConstructors))] + public void LeftShift(TensorConstructor tensorConstructor) + { + var left = tensorConstructor.CreateFromArray( + new[,] + { + {0, 1, 2}, + {3, 4, 5} + }); + + var expected = tensorConstructor.CreateFromArray( + new[,] + { + {0, 2, 4}, + {6, 8, 10} + }); + + var actual = TensorOperations.LeftShift(left, 1); + Assert.True(StructuralComparisons.StructuralEqualityComparer.Equals(actual, expected)); + Assert.Equal(tensorConstructor.IsReversedStride, actual.IsReversedStride); + } + + [Theory()] + [MemberData(nameof(GetSingleTensorConstructors))] + public void RightShift(TensorConstructor tensorConstructor) + { + var left = tensorConstructor.CreateFromArray( + new[,] + { + {0, 1, 2}, + {3, 4, 5} + }); + + var expected = tensorConstructor.CreateFromArray( + new[,] + { + {0, 0, 1}, + {1, 2, 2} + }); + + var actual = TensorOperations.RightShift(left, 1); + Assert.True(StructuralComparisons.StructuralEqualityComparer.Equals(actual, expected)); + Assert.Equal(tensorConstructor.IsReversedStride, actual.IsReversedStride); + } + + [Theory()] + [MemberData(nameof(GetDualTensorConstructors))] + public void ElementWiseEquals(TensorConstructor leftConstructor, TensorConstructor rightConstructor) + { + var left = leftConstructor.CreateFromArray( + new[,] + { + {0, 1, 2}, + {3, 4, 5} + }); + var right = rightConstructor.CreateFromArray( + new[,] + { + {0, 1, -2}, + {2, 3, 5} + }); + + var expected = new[,] + { + {true, true, false }, + {false, false, true} + }.ToTensor(); + + var actual = TensorOperations.Equals(left, right); + Assert.True(StructuralComparisons.StructuralEqualityComparer.Equals(actual, expected)); + Assert.Equal(leftConstructor.IsReversedStride, actual.IsReversedStride); + } + + [Theory()] + [MemberData(nameof(GetDualTensorConstructors))] + public void ElementWiseNotEquals(TensorConstructor leftConstructor, TensorConstructor rightConstructor) + { + var left = leftConstructor.CreateFromArray( + new[,] + { + {0, 1, 2}, + {3, 4, 5} + }); + var right = rightConstructor.CreateFromArray( + new[,] + { + {0, 1, -2}, + {2, 3, 5} + }); + + var expected = new[,] + { + {false, false, true}, + {true, true, false} + }.ToTensor(); + + var actual = TensorOperations.NotEquals(left, right); + Assert.True(StructuralComparisons.StructuralEqualityComparer.Equals(actual, expected)); + Assert.Equal(leftConstructor.IsReversedStride, actual.IsReversedStride); + } + + [Theory] + [MemberData(nameof(GetDualTensorConstructors))] + public void MatrixMultiply(TensorConstructor leftConstructor, TensorConstructor rightConstructor) + { + var left = leftConstructor.CreateFromArray( + new[,] + { + {0, 1, 2}, + {3, 4, 5} + }); + + var right = rightConstructor.CreateFromArray( + new[,] + { + {0, 1, 2, 3, 4}, + {5, 6, 7, 8, 9}, + {10, 11, 12, 13, 14} + }); + + var expected = leftConstructor.CreateFromArray( + new[,] + { + {0*0 + 1*5 + 2*10, 0*1 + 1*6 + 2*11, 0*2 + 1*7 + 2*12, 0*3 + 1*8 + 2*13, 0*4 + 1*9 + 2*14}, + {3*0 + 4*5 + 5*10, 3*1 + 4*6 + 5*11, 3*2 + 4*7 + 5*12, 3*3 + 4*8 + 5*13, 3*4 + 4*9 + 5*14} + }); + + var actual = left.MatrixMultiply(right); + Assert.True(StructuralComparisons.StructuralEqualityComparer.Equals(actual, expected)); + } + + + [Theory] + [MemberData(nameof(GetDualTensorConstructors))] + public void Contract(TensorConstructor leftConstructor, TensorConstructor 
rightConstructor) + { + var left = leftConstructor.CreateFromArray( + new[, ,] + { + { + {0, 1}, + {2, 3} + }, + { + {4, 5}, + {6, 7} + }, + { + {8, 9}, + {10, 11} + } + }); + + var right = rightConstructor.CreateFromArray( + new[, ,] + { + { + {0, 1}, + {2, 3}, + {4, 5} + }, + { + {6, 7}, + {8, 9}, + {10, 11} + }, + { + {12, 13}, + {14, 15}, + {16, 17} + }, + { + {18, 19}, + {20, 21}, + {22, 23} + } + }); + + // contract a 3*2*2 with a 4*3*2 tensor, summing on (3*2)*2 and 4*(3*2) to produce a 2*4 tensor + var expected = leftConstructor.CreateFromArray( + new[,] + { + {110, 290, 470, 650}, + {125, 341, 557, 773}, + }); + var actual = TensorOperations.Contract(left, right, new[] { 0, 1 }, new[] { 1, 2 }); + Assert.True(StructuralComparisons.StructuralEqualityComparer.Equals(actual, expected)); + + // contract a 3*2*2 with a 4*3*2 tensor, summing on (3)*2*(2) and 4*(3*2) to produce a 2*4 tensor + expected = leftConstructor.CreateFromArray( + new[,] + { + {101, 263, 425, 587}, + {131, 365, 599, 833}, + }); + actual = TensorOperations.Contract(left, right, new[] { 0, 2 }, new[] { 1, 2 }); + Assert.True(StructuralComparisons.StructuralEqualityComparer.Equals(actual, expected)); + } + + + [Theory] + [MemberData(nameof(GetDualTensorConstructors))] + public void ContractWithSingleLengthDimension(TensorConstructor leftConstructor, TensorConstructor rightConstructor) + { + var left = leftConstructor.CreateFromArray( + new[,] + { + {1, 2, 3}, + {4, 5, 6}, + }); + + var right = rightConstructor.CreateFromArray( + new[,] + { + { 1, 2 }, + { 3, 4 }, + { 5, 6 } + }); + + var expected = leftConstructor.CreateFromArray( + new[,] + { + { 22, 28 }, + { 49, 64 } + }); + + // contract a 2*3 with a 3*2 tensor, summing on 2*(3) and (3)*2 to produce a 2*2 tensor + var actual = TensorOperations.Contract(left, right, new[] { 1 }, new[] { 0 }); + Assert.True(StructuralComparisons.StructuralEqualityComparer.Equals(actual, expected)); + + // contract a 1*2*3*1 with a 3*2 tensor, summing on 1*2*(3)*1 and (3)*2 to produce a 1*2*1*2 tensor + var reshapedLeft = left.Reshape(new int[] { 1, 2, 3, 1 }); + var reshapedExpected = expected.Reshape(new int[] { 1, 2, 1, 2 }); + actual = TensorOperations.Contract(reshapedLeft, right, new[] { 2 }, new[] { 0 }); + Assert.True(StructuralComparisons.StructuralEqualityComparer.Equals(actual, reshapedExpected)); + + } + + [Theory] + [MemberData(nameof(GetDualTensorConstructors))] + public void ContractMismatchedDimensions(TensorConstructor leftConstructor, TensorConstructor rightConstructor) + { + var left = leftConstructor.CreateFromArray( + new[] { 0, 1, 2, 3 }); + + var right = rightConstructor.CreateFromArray( + new[,] + { + { 0 }, + { 1 }, + { 2 } + }); + + var expected = leftConstructor.CreateFromArray( + new[,] + { + {0,0,0}, + {0,1,2}, + {0,2,4}, + {0,3,6}, + }); + + Assert.Throws(() => TensorOperations.Contract(left, right, new int[] { }, new[] { 1 })); + + // reshape to include dimension of length 1. 
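+            // Contracting the two length-1 axes leaves a 4*3 result equal to the outer product of the 4-element left vector and the 3-element right column:
+            // result[i, j] = left[i] * right[j, 0], e.g. result[3, 2] = 3 * 2 = 6, matching 'expected' above.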
+ var leftReshaped = left.Reshape(new[] { 1, (int)left.Length }); + + var actual = TensorOperations.Contract(leftReshaped, right, new[] { 0 }, new[] { 1 }); + Assert.True(StructuralComparisons.StructuralEqualityComparer.Equals(actual, expected)); + } + + [Theory] + [MemberData(nameof(GetSingleTensorConstructors))] + public void GetArrayString(TensorConstructor constructor) + { + var tensor = constructor.CreateFromArray( + new[, ,] + { + { + {0, 1}, + {2, 3}, + {4, 5} + }, + { + {6, 7}, + {8, 9}, + {10, 11} + }, + { + {12, 13}, + {14, 15}, + {16, 17} + }, + { + {18, 19}, + {20, 21}, + {22, 23} + } + }); + + var expected = +@"{ + { + {0,1}, + {2,3}, + {4,5} + }, + { + {6,7}, + {8,9}, + {10,11} + }, + { + {12,13}, + {14,15}, + {16,17} + }, + { + {18,19}, + {20,21}, + {22,23} + } +}"; + + Assert.Equal(expected, tensor.GetArrayString()); + + var expectedNoSpace = expected.Replace(Environment.NewLine, "").Replace(" ", ""); + Assert.Equal(expectedNoSpace, tensor.GetArrayString(false)); + } + + + [Theory] + [MemberData(nameof(GetSingleTensorConstructors))] + public void TestICollectionMembers(TensorConstructor constructor) + { + var arr = new[,] + { + { 1, 2, 3 }, + { 4, 5, 6 } + }; + + var tensor = constructor.CreateFromArray(arr); + ICollection tensorCollection = tensor; + + Assert.Equal(6, tensorCollection.Count); + + Assert.False(tensorCollection.IsSynchronized); + + Assert.True(ReferenceEquals(tensorCollection, tensorCollection.SyncRoot)); + + var actual = Array.CreateInstance(typeof(int), tensor.Length); + tensorCollection.CopyTo(actual, 0); + var expected = constructor.IsReversedStride ? + new[] { 1, 4, 2, 5, 3, 6 } : + new[] { 1, 2, 3, 4, 5, 6 }; + Assert.Equal(expected, actual); + + actual = Array.CreateInstance(typeof(int), tensor.Length + 2); + tensorCollection.CopyTo(actual, 2); + expected = constructor.IsReversedStride ? + new[] { 0, 0, 1, 4, 2, 5, 3, 6 } : + new[] { 0, 0, 1, 2, 3, 4, 5, 6 }; + Assert.Equal(expected, actual); + + Assert.Throws(() => tensorCollection.CopyTo(null, 0)); + Assert.Throws(() => tensorCollection.CopyTo(new int[3, 4], 0)); + Assert.Throws(() => tensorCollection.CopyTo(new int[5], 0)); + Assert.Throws(() => tensorCollection.CopyTo(new int[6], 1)); + } + + [Theory] + [MemberData(nameof(GetSingleTensorConstructors))] + public void TestIListMembers(TensorConstructor constructor) + { + var arr = new[,] + { + { 1, 2, 3 }, + { 4, 5, 6 } + }; + + var tensor = constructor.CreateFromArray(arr); + IList tensorList = tensor; + + int expectedIndexValue = constructor.IsReversedStride ? 4 : 2; + Assert.Equal(expectedIndexValue, tensorList[1]); + + tensorList[1] = 7; + Assert.Equal(7, tensorList[1]); + var expected = constructor.IsReversedStride ? + new[] { 1, 7, 2, 5, 3, 6 } : + new[] { 1, 7, 3, 4, 5, 6 }; + Assert.Equal(expected, tensor); + + Assert.True(tensorList.IsFixedSize); + Assert.False(tensorList.IsReadOnly); + + Assert.Throws(() => (tensorList).Add(8)); + + Assert.True(tensorList.Contains(5)); + Assert.True(tensorList.Contains(6)); + Assert.False(tensorList.Contains(0)); + Assert.False(tensorList.Contains(42)); + Assert.False(tensorList.Contains("foo")); + + Assert.Equal(constructor.IsReversedStride ? 
3 : 4, tensorList.IndexOf(5)); + Assert.Equal(5, tensorList.IndexOf(6)); + Assert.Equal(-1, tensorList.IndexOf(0)); + Assert.Equal(-1, tensorList.IndexOf(42)); + + Assert.Throws(() => (tensorList).Insert(2, 5)); + Assert.Throws(() => (tensorList).Remove(1)); + Assert.Throws(() => (tensorList).RemoveAt(0)); + + tensorList.Clear(); + Assert.Equal(new[] { 0, 0, 0, 0, 0, 0 }, tensor); + } + + [Theory] + [MemberData(nameof(GetSingleTensorConstructors))] + public void TestICollectionTMembers(TensorConstructor constructor) + { + var arr = new[,] + { + { 1, 2, 3 }, + { 4, 5, 6 } + }; + + var tensor = constructor.CreateFromArray(arr); + ICollection tensorCollection = tensor; + + Assert.Equal(6, tensorCollection.Count); + Assert.False(tensorCollection.IsReadOnly); + + Assert.Throws(() => tensorCollection.Add(8)); + Assert.Throws(() => tensorCollection.Remove(1)); + + Assert.True(tensorCollection.Contains(5)); + Assert.True(tensorCollection.Contains(6)); + Assert.False(tensorCollection.Contains(0)); + Assert.False(tensorCollection.Contains(42)); + + var actual = new int[tensor.Length]; + tensorCollection.CopyTo(actual, 0); + var expected = constructor.IsReversedStride ? + new[] { 1, 4, 2, 5, 3, 6 } : + new[] { 1, 2, 3, 4, 5, 6 }; + Assert.Equal(expected, actual); + + actual = new int[tensor.Length + 2]; + tensorCollection.CopyTo(actual, 2); + expected = constructor.IsReversedStride ? + new[] { 0, 0, 1, 4, 2, 5, 3, 6 } : + new[] { 0, 0, 1, 2, 3, 4, 5, 6 }; + Assert.Equal(expected, actual); + + Assert.Throws(() => tensorCollection.CopyTo(null, 0)); + Assert.Throws(() => tensorCollection.CopyTo(new int[5], 0)); + Assert.Throws(() => tensorCollection.CopyTo(new int[6], 1)); + + tensorCollection.Clear(); + Assert.Equal(new[] { 0, 0, 0, 0, 0, 0 }, tensor); + } + + [Theory] + [MemberData(nameof(GetSingleTensorConstructors))] + public void TestIListTMembers(TensorConstructor constructor) + { + var arr = new[,] + { + { 1, 2, 3 }, + { 4, 5, 6 } + }; + + var tensor = constructor.CreateFromArray(arr); + IList tensorList = tensor; + + int expectedIndexValue = constructor.IsReversedStride ? 4 : 2; + Assert.Equal(expectedIndexValue, tensorList[1]); + + tensorList[1] = 7; + Assert.Equal(7, tensorList[1]); + var expected = constructor.IsReversedStride ? + new[] { 1, 7, 2, 5, 3, 6 } : + new[] { 1, 7, 3, 4, 5, 6 }; + Assert.Equal(expected, tensor); + + Assert.Equal(constructor.IsReversedStride ? 3 : 4, tensorList.IndexOf(5)); + Assert.Equal(5, tensorList.IndexOf(6)); + Assert.Equal(-1, tensorList.IndexOf(0)); + Assert.Equal(-1, tensorList.IndexOf(42)); + + Assert.Throws(() => (tensorList).Insert(2, 5)); + Assert.Throws(() => (tensorList).RemoveAt(0)); + } + + [Theory] + [MemberData(nameof(GetSingleTensorConstructors))] + public void TestIReadOnlyTMembers(TensorConstructor constructor) + { + var arr = new[,] + { + { 1, 2, 3 }, + { 4, 5, 6 } + }; + + var tensor = constructor.CreateFromArray(arr); + + IReadOnlyCollection tensorCollection = tensor; + Assert.Equal(6, tensorCollection.Count); + + IReadOnlyList tensorList = tensor; + int expectedIndexValue = constructor.IsReversedStride ? 4 : 2; + Assert.Equal(expectedIndexValue, tensorList[1]); + } + } +} diff --git a/csharp/test/Microsoft.ML.OnnxRuntime.Tests/Tensors/TensorTestsBase.cs b/csharp/test/Microsoft.ML.OnnxRuntime.Tests/Tensors/TensorTestsBase.cs new file mode 100644 index 0000000000000..f7b2ac774e650 --- /dev/null +++ b/csharp/test/Microsoft.ML.OnnxRuntime.Tests/Tensors/TensorTestsBase.cs @@ -0,0 +1,164 @@ +// Copyright (c) Microsoft Corporation. 
All rights reserved. +// Licensed under the MIT License. + +// This file is copied and adapted from the following git repository - +// https://github.com/dotnet/corefx +// Commit ID: bdd0814360d4c3a58860919f292a306242f27da1 +// Path: /src/System.Numerics.Tensors/tests/TensorTestsBase.cs +// Original license statement below - + +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System.Collections.Generic; +using System; + +namespace Microsoft.ML.OnnxRuntime.Tensors.Tests +{ + public class TensorTestsBase + { + public enum TensorType + { + Dense + }; + + public class TensorConstructor + { + public TensorType TensorType { get; set; } + + public bool IsReversedStride { get; set; } + + public Tensor<int> CreateFromArray(Array array) + { + switch (TensorType) + { + case TensorType.Dense: + return array.ToTensor<int>(IsReversedStride); + } + + throw new ArgumentException(nameof(TensorType)); + } + public Tensor<int> CreateFromDimensions(ReadOnlySpan<int> dimensions) + { + switch (TensorType) + { + case TensorType.Dense: + return new DenseTensor<int>(dimensions, IsReversedStride); + } + + throw new ArgumentException(nameof(TensorType)); + } + + public override string ToString() + { + return $"{TensorType}, {nameof(IsReversedStride)} = {IsReversedStride}"; + } + } + + private static TensorType[] s_tensorTypes = new[] + { + TensorType.Dense + }; + + private static bool[] s_reverseStrideValues = new[] + { + false, + true + }; + + public static IEnumerable<object[]> GetSingleTensorConstructors() + { + foreach (TensorType tensorType in s_tensorTypes) + { + foreach (bool isReversedStride in s_reverseStrideValues) + { + yield return new[] + { + new TensorConstructor() + { + TensorType = tensorType, + IsReversedStride = isReversedStride + } + }; + } + } + } + + public static IEnumerable<object[]> GetDualTensorConstructors() + { + foreach (TensorType leftTensorType in s_tensorTypes) + { + foreach (TensorType rightTensorType in s_tensorTypes) + { + foreach (bool isLeftReversedStride in s_reverseStrideValues) + { + foreach (bool isRightReversedStride in s_reverseStrideValues) + { + yield return new[] + { + new TensorConstructor() + { + TensorType = leftTensorType, + IsReversedStride = isLeftReversedStride + }, + new TensorConstructor() + { + TensorType = rightTensorType, + IsReversedStride = isRightReversedStride + } + }; + } + } + } + } + } + + public static IEnumerable<object[]> GetTensorAndResultConstructor() + { + foreach (TensorType leftTensorType in s_tensorTypes) + { + foreach (TensorType rightTensorType in s_tensorTypes) + { + foreach (bool isReversedStride in s_reverseStrideValues) + { + yield return new[] + { + new TensorConstructor() + { + TensorType = leftTensorType, + IsReversedStride = isReversedStride + }, + new TensorConstructor() + { + TensorType = rightTensorType, + IsReversedStride = isReversedStride + } + }; + } + } + } + } + + public static NativeMemory<T> NativeMemoryFromArray<T>(T[] array) + { + return NativeMemoryFromArray<T>((Array)array); + } + + public static NativeMemory<T> NativeMemoryFromArray<T>(Array array) + { + // this silly method takes a managed array and copies it over to unmanaged memory, + // **only for test purposes** + + var memory = NativeMemory<T>.Allocate(array.Length); + var span = memory.GetSpan(); + int index = 0; + foreach (T item in array) + { + span[index++] = item; + } + + return memory; + } + } +} diff --git a/csharp/testdata/test_types_BOOL.pb
b/csharp/testdata/test_types_BOOL.pb index 005aa79303179a00829b60fa90054a93960f4d30..2c58b06d0aa6d1adcd80239aecec060807492e93 100644 GIT binary patch delta 76 zcmZ3^IGvG~gIS2((bF?8ttio|bs}$wnK&0~W?n&Qi4Y$b4+p0Z2Nx3uBM`GDNpP{{ NmzH3Wabgl+2LM?B4BY?# delta 92 zcmbQvxSWxfgIS2((bF?8ttioI>O|fUH+e4B%)Elq5+N}z5e`lv0WKyEMr53%z{Qqd MT7pM~6O#Zt03MGGi2wiq diff --git a/csharp/testdata/test_types_INT8.pb b/csharp/testdata/test_types_INT8.pb index b6947e990a46ae50b8450502eb3616ef50529bc5..72698a779578d667502629bfbdbb7836e2257a7c 100644 GIT binary patch delta 76 zcmZ3^IGvG~gIS2((bF?8ttio|bs}$wnK&0~W?n&Qi4Y$b4^WbWi;05~h*^^)xY+Vb MOR&f|F$u5(09UXK(EtDd delta 92 zcmbQvxSWxfgIS2((bF?8ttioI>O|fUH+e4B%)Elq5+N}z5ul_17ZV2~GEP$9V#_Zr L!K1>7Nq`*y9Ip+1 diff --git a/csharp/testdata/test_types_STRING.pb b/csharp/testdata/test_types_STRING.pb index 16927acbcc5e3604326067719f330c0c746753db..7c8b3e7e2eb82ada293453d760007b565f92c759 100644 GIT binary patch delta 76 zcmZ3^IGvG~gIS2((bF?8ttio|bs}$wnK&0~W?n&Qi4Y$b4+n=32Nx3uBM`GDNpP{{ NmzH3Wabgl+2LM>e4BG$z delta 92 zcmbQvxSWxfgIS2((bF?8ttioI>O|fUH+e4B%)Elq5+N}z5e^O^0WKyEMr53%z{Qqd MT7pM~6O#Zt03JULhX4Qo diff --git a/csharp/tools/Microsoft.ML.OnnxRuntime.PerfTool/Program.cs b/csharp/tools/Microsoft.ML.OnnxRuntime.PerfTool/Program.cs index 0e9915ad6696d..79b89aec17acb 100644 --- a/csharp/tools/Microsoft.ML.OnnxRuntime.PerfTool/Program.cs +++ b/csharp/tools/Microsoft.ML.OnnxRuntime.PerfTool/Program.cs @@ -3,7 +3,7 @@ using System; using System.Collections.Generic; -using System.Numerics.Tensors; +using Microsoft.ML.OnnxRuntime.Tensors; using System.Diagnostics; using CommandLine; @@ -33,7 +33,7 @@ class CommandOptions public bool ParallelExecution { get; set; } = false; [Option('o', "optimization_level", Required = false, HelpText = "Optimization Level. 
Default is 1, partial optimization.")] - public uint OptimizationLevel { get; set; } = 1; + public GraphOptimizationLevel OptimizationLevel { get; set; } = GraphOptimizationLevel.ORT_ENABLE_BASIC; } class Program @@ -42,7 +42,8 @@ public static void Main(string[] args) { var cmdOptions = Parser.Default.ParseArguments(args); cmdOptions.WithParsed( - options => { + options => + { Run(options); }); } @@ -52,7 +53,7 @@ public static void Run(CommandOptions options) string inputPath = options.InputFile; int iteration = options.IterationCount; bool parallelExecution = options.ParallelExecution; - uint optLevel = options.OptimizationLevel; + GraphOptimizationLevel optLevel = options.OptimizationLevel; Console.WriteLine("Running model {0} in OnnxRuntime:", modelPath); Console.WriteLine("input:{0}", inputPath); Console.WriteLine("iteration count:{0}", iteration); @@ -84,17 +85,17 @@ public static float[] LoadTensorFromFile(string filename) return tensorData.ToArray(); } - static void RunModelOnnxRuntime(string modelPath, string inputPath, int iteration, DateTime[] timestamps, bool parallelExecution, uint optLevel) + static void RunModelOnnxRuntime(string modelPath, string inputPath, int iteration, DateTime[] timestamps, bool parallelExecution, GraphOptimizationLevel optLevel) { if (timestamps.Length != (int)TimingPoint.TotalCount) { - throw new ArgumentException("Timestamps array must have "+(int)TimingPoint.TotalCount+" size"); + throw new ArgumentException("Timestamps array must have " + (int)TimingPoint.TotalCount + " size"); } timestamps[(int)TimingPoint.Start] = DateTime.Now; SessionOptions options = new SessionOptions(); - if (parallelExecution) options.DisableSequentialExecution(); - options.SetSessionGraphOptimizationLevel(optLevel); + if (parallelExecution) options.EnableSequentialExecution = false; + options.GraphOptimizationLevel = optLevel; using (var session = new InferenceSession(modelPath, options)) { timestamps[(int)TimingPoint.ModelLoaded] = DateTime.Now; @@ -108,12 +109,12 @@ static void RunModelOnnxRuntime(string modelPath, string inputPath, int iteratio container.Add(NamedOnnxValue.CreateFromTensor(name, tensor)); } - + timestamps[(int)TimingPoint.InputLoaded] = DateTime.Now; // Run the inference - for (int i=0; i < iteration; i++) + for (int i = 0; i < iteration; i++) { var results = session.Run(container); // results is an IReadOnlyList container Debug.Assert(results != null); @@ -132,7 +133,7 @@ static void RunModelOnnxRuntime(string modelPath, string inputPath, int iteratio static void PrintUsage() { Console.WriteLine("Usage:\n" - +"dotnet Microsoft.ML.OnnxRuntime.PerfTool " + + "dotnet Microsoft.ML.OnnxRuntime.PerfTool " ); } diff --git a/dockerfiles/Dockerfile.cuda b/dockerfiles/Dockerfile.cuda index 0a537b774873a..0358629b28e8f 100644 --- a/dockerfiles/Dockerfile.cuda +++ b/dockerfiles/Dockerfile.cuda @@ -18,11 +18,13 @@ WORKDIR /code ENV PATH /usr/local/nvidia/bin:/usr/local/cuda/bin:/code/cmake-3.14.3-Linux-x86_64/bin:/opt/miniconda/bin:${PATH} # Prepare onnxruntime repository & build onnxruntime with TensorRT -RUN git clone --single-branch --branch ${ONNXRUNTIME_SERVER_BRANCH} --recursive ${ONNXRUNTIME_REPO} onnxruntime &&\ - /bin/sh onnxruntime/dockerfiles/install_common_deps.sh &&\ +ADD scripts /tmp/scripts +RUN /bin/sh /tmp/scripts/install_common_deps.sh && \ + git clone --single-branch --branch ${ONNXRUNTIME_SERVER_BRANCH} --recursive ${ONNXRUNTIME_REPO} onnxruntime &&\ + cp onnxruntime/ThirdPartyNotices.txt /code/ThirdPartyNotices.txt &&\ cp 
onnxruntime/dockerfiles/LICENSE-IMAGE.txt /code/LICENSE-IMAGE.txt &&\ cd onnxruntime &&\ /bin/sh ./build.sh --cuda_home /usr/local/cuda --cudnn_home /usr/lib/x86_64-linux-gnu/ --use_cuda --config Release --build_wheel --update --build --cmake_extra_defines ONNXRUNTIME_VERSION=$(cat ./VERSION_NUMBER) &&\ pip install /code/onnxruntime/build/Linux/Release/dist/*.whl &&\ cd .. &&\ - rm -rf onnxruntime cmake-3.14.3-Linux-x86_64.tar.gz cmake-3.14.3-Linux-x86_64 + rm -rf onnxruntime cmake-3.14.3-Linux-x86_64 diff --git a/dockerfiles/Dockerfile.openvino b/dockerfiles/Dockerfile.openvino index a829e1651de60..3574725654b4c 100644 --- a/dockerfiles/Dockerfile.openvino +++ b/dockerfiles/Dockerfile.openvino @@ -6,7 +6,8 @@ FROM ubuntu:16.04 RUN apt update && \ - apt -y install python3.5 python3-pip zip x11-apps lsb-core wget cpio sudo libboost-python-dev libpng-dev zlib1g-dev git libnuma1 ocl-icd-libopencl1 clinfo libboost-filesystem1.58.0 libboost-thread1.58.0 protobuf-compiler libprotoc-dev libusb-1.0-0-dev && pip3 install numpy networkx opencv-python pytest && locale-gen en_US.UTF-8 && update-locale LANG=en_US.UTF-8 + apt -y install git sudo wget \ + zip x11-apps lsb-core cpio libboost-python-dev libpng-dev zlib1g-dev libnuma1 ocl-icd-libopencl1 clinfo libboost-filesystem1.58.0 libboost-thread1.58.0 protobuf-compiler libprotoc-dev libusb-1.0-0-dev ARG DEVICE=CPU_FP32 ARG ONNXRUNTIME_REPO=https://github.com/microsoft/onnxruntime @@ -39,7 +40,7 @@ ENV OpenCV_DIR=${INTEL_OPENVINO_DIR}/opencv/share/OpenCV ENV LD_LIBRARY_PATH=${INTEL_OPENVINO_DIR}/opencv/lib:${INTEL_OPENVINO_DIR}/opencv/share/OpenCV/3rdparty/lib:${LD_LIBRARY_PATH} ENV PATH=${INTEL_CVSDK_DIR}/deployment_tools/model_optimizer:$PATH ENV PYTHONPATH=${INTEL_CVSDK_DIR}/deployment_tools/model_optimizer:$PYTHONPATH -ENV PYTHONPATH=$INTEL_CVSDK_DIR/python/python3.5:${INTEL_CVSDK_DIR}/python/python3.5/ubuntu16:${PYTHONPATH} +# ENV PYTHONPATH=$INTEL_CVSDK_DIR/python/python3.5:${INTEL_CVSDK_DIR}/python/python3.5/ubuntu16:${PYTHONPATH} ENV HDDL_INSTALL_DIR=${INTEL_OPENVINO_DIR}/deployment_tools/inference_engine/external/hddl ENV LD_LIBRARY_PATH=${INTEL_OPENVINO_DIR}/deployment_tools/inference_engine/external/hddl/lib:$LD_LIBRARY_PATH @@ -51,16 +52,15 @@ RUN wget https://github.com/intel/compute-runtime/releases/download/19.15.12831/ RUN sudo dpkg -i *.deb && rm -rf *.deb - -RUN mkdir -p /opt/cmake/bin - -ENV PATH /opt/cmake/bin:$PATH ENV LANG en_US.UTF-8 -RUN wget https://github.com/Kitware/CMake/releases/download/v3.13.2/cmake-3.13.2-Linux-x86_64.tar.gz && \ - tar -xf cmake-3.13.2-Linux-x86_64.tar.gz --strip 1 -C /opt/cmake && rm -rf /cmake-3.13.2-Linux-x86_64.tar.gz +WORKDIR /code +ENV PATH /opt/miniconda/bin:/code/cmake-3.14.3-Linux-x86_64/bin:$PATH -RUN git clone --recursive -b $ONNXRUNTIME_BRANCH $ONNXRUNTIME_REPO /onnxruntime && \ +ADD scripts /tmp/scripts +RUN /bin/sh /tmp/scripts/install_common_deps.sh && \ + git clone --recursive -b $ONNXRUNTIME_BRANCH $ONNXRUNTIME_REPO /onnxruntime && \ cd /onnxruntime/cmake/external/onnx && python3 setup.py install && \ - cd /onnxruntime && ./build.sh --config RelWithDebInfo --update --build --parallel --use_openvino $DEVICE --build_wheel && pip3 install /onnxruntime/build/Linux/RelWithDebInfo/dist/*-linux_x86_64.whl && rm -rf /onnxruntime - - + cp /onnxruntime/dockerfiles/LICENSE-IMAGE.txt /code/LICENSE-IMAGE.txt && \ + cp /onnxruntime/ThirdPartyNotices.txt /code/ThirdPartyNotices.txt && \ + cd /onnxruntime && ./build.sh --config RelWithDebInfo --update --build --parallel --use_openvino $DEVICE 
--build_wheel && \ + pip install /onnxruntime/build/Linux/RelWithDebInfo/dist/*-linux_x86_64.whl && rm -rf /onnxruntime cmake-3.14.3-Linux-x86_64 diff --git a/dockerfiles/Dockerfile.server b/dockerfiles/Dockerfile.server index 3eebdf6db036b..29fff56eadc3b 100644 --- a/dockerfiles/Dockerfile.server +++ b/dockerfiles/Dockerfile.server @@ -32,9 +32,8 @@ RUN mkdir -p /onnxruntime/build && \ FROM minimal AS final WORKDIR /onnxruntime/server/ -ENV MODEL_ABSOLUTE_PATH /onnxruntime/model/model.onnx COPY --from=build /onnxruntime/build/Release/onnxruntime_server /onnxruntime/server/ COPY --from=build /onnxruntime/build/Release/libonnxruntime.so.* /lib/ RUN apt-get update \ && apt-get install -y libgomp1 -ENTRYPOINT /onnxruntime/server/onnxruntime_server --model_path $MODEL_ABSOLUTE_PATH +ENTRYPOINT ["/onnxruntime/server/onnxruntime_server"] diff --git a/dockerfiles/Dockerfile.source b/dockerfiles/Dockerfile.source index 1a0f0921136fb..a0880ee68d84e 100644 --- a/dockerfiles/Dockerfile.source +++ b/dockerfiles/Dockerfile.source @@ -17,11 +17,12 @@ RUN apt-get update &&\ WORKDIR /code ENV PATH /opt/miniconda/bin:/code/cmake-3.14.3-Linux-x86_64/bin:${PATH} +ADD scripts /tmp/scripts # Prepare onnxruntime repository & build onnxruntime -RUN git clone --single-branch --branch ${ONNXRUNTIME_SERVER_BRANCH} --recursive ${ONNXRUNTIME_REPO} onnxruntime &&\ - /bin/sh onnxruntime/dockerfiles/install_common_deps.sh &&\ +RUN /bin/sh /tmp/scripts/install_common_deps.sh &&\ + git clone --single-branch --branch ${ONNXRUNTIME_SERVER_BRANCH} --recursive ${ONNXRUNTIME_REPO} onnxruntime &&\ cd onnxruntime &&\ /bin/sh ./build.sh --config Release --build_wheel --update --build --cmake_extra_defines ONNXRUNTIME_VERSION=$(cat ./VERSION_NUMBER) &&\ pip install /code/onnxruntime/build/Linux/Release/dist/*.whl &&\ cd .. &&\ - rm -rf onnxruntime cmake-3.14.3-Linux-x86_64.tar.gz cmake-3.14.3-Linux-x86_64 + rm -rf onnxruntime cmake-3.14.3-Linux-x86_64 diff --git a/dockerfiles/Dockerfile.tensorrt b/dockerfiles/Dockerfile.tensorrt index 6f3df1fbbba81..6ffd3b19f3f1a 100644 --- a/dockerfiles/Dockerfile.tensorrt +++ b/dockerfiles/Dockerfile.tensorrt @@ -17,12 +17,14 @@ RUN apt-get update &&\ WORKDIR /code ENV PATH /usr/local/nvidia/bin:/usr/local/cuda/bin:/code/cmake-3.14.3-Linux-x86_64/bin:/opt/miniconda/bin:${PATH} +ADD scripts /tmp/scripts # Prepare onnxruntime repository & build onnxruntime with TensorRT -RUN git clone --single-branch --branch ${ONNXRUNTIME_SERVER_BRANCH} --recursive ${ONNXRUNTIME_REPO} onnxruntime &&\ - /bin/sh onnxruntime/dockerfiles/install_common_deps.sh &&\ +RUN /bin/sh /tmp/scripts/install_common_deps.sh && \ + git clone --single-branch --branch ${ONNXRUNTIME_SERVER_BRANCH} --recursive ${ONNXRUNTIME_REPO} onnxruntime &&\ cp onnxruntime/dockerfiles/LICENSE-IMAGE.txt /code/LICENSE-IMAGE.txt &&\ + cp onnxruntime/ThirdPartyNotices.txt /code/ThirdPartyNotices.txt &&\ cd onnxruntime &&\ /bin/sh ./build.sh --cuda_home /usr/local/cuda --cudnn_home /usr/lib/x86_64-linux-gnu/ --use_tensorrt --tensorrt_home /workspace/tensorrt --config Release --build_wheel --update --build --cmake_extra_defines ONNXRUNTIME_VERSION=$(cat ./VERSION_NUMBER) &&\ pip install /code/onnxruntime/build/Linux/Release/dist/*.whl &&\ cd .. 
&&\ - rm -rf onnxruntime cmake-3.14.3-Linux-x86_64.tar.gz cmake-3.14.3-Linux-x86_64 + rm -rf onnxruntime cmake-3.14.3-Linux-x86_64 diff --git a/dockerfiles/README.md b/dockerfiles/README.md index f395acc2ef6b5..8fdc7959d0437 100644 --- a/dockerfiles/README.md +++ b/dockerfiles/README.md @@ -8,7 +8,7 @@ - [OpenVINO](Dockerfile.openvino) - [ONNX Runtime Server](Dockerfile.server) -## Build from Source Version (Preview) +## Build from Source #### Linux 16.04, CPU, Python Bindings 1. Build the docker image from the Dockerfile in this repository. @@ -26,7 +26,7 @@ docker run -it onnxruntime-source ``` -## CUDA Version (Preview) +## CUDA #### Linux 16.04, CUDA 10.0, CuDNN 7 1. Build the docker image from the Dockerfile in this repository. @@ -44,7 +44,7 @@ docker run -it onnxruntime-cuda ``` -## nGraph Version (Preview) +## nGraph (Public Preview) #### Linux 16.04, Python Bindings 1. Build the docker image from the Dockerfile in this repository. @@ -62,7 +62,7 @@ docker run -it onnxruntime-ngraph ``` -## TensorRT Version (Preview) +## TensorRT #### Linux 16.04, TensorRT 5.0.2 1. Build the docker image from the Dockerfile in this repository. @@ -80,7 +80,7 @@ docker run -it onnxruntime-trt ``` -## OpenVINO Version (Preview) +## OpenVINO (Public Preview) #### Linux 16.04, Python Bindings 1. Build the onnxruntime image for all the accelerators supported as below @@ -104,7 +104,7 @@ | MYRIAD_FP16 | Intel MovidiusTM USB sticks | | VAD-M_FP16 | Intel Vision Accelerator Design based on MovidiusTM MyriadX VPUs | -## CPU Version +## CPU 1. Retrieve your docker image in one of the following ways. @@ -122,7 +122,7 @@ docker run -it onnxruntime-cpu ``` -## GPU Version +## GPU 1. Retrieve your docker image in one of the following ways. - Build the docker image from the DockerFile in this repository. @@ -138,7 +138,7 @@ ``` docker run -it --device /dev/dri:/dev/dri onnxruntime-gpu:latest ``` -## Myriad VPU Accelerator Version +## Myriad VPU Accelerator 1. Retrieve your docker image in one of the following ways. - Build the docker image from the DockerFile in this repository. @@ -155,6 +155,7 @@ docker run -it --network host --privileged -v /dev:/dev onnxruntime-myriad:latest ``` +======= ## VAD-M Accelerator Version 1. Retrieve your docker image in one of the following ways. @@ -172,7 +173,7 @@ docker run -it --device --mount type=bind,source=/var/tmp,destination=/var/tmp --device /dev/ion:/dev/ion onnxruntime-hddl:latest ``` -## ONNX Runtime Server (Preview) +## ONNX Runtime Server (Public Preview) #### Linux 16.04 1. Build the docker image from the Dockerfile in this repository @@ -183,7 +184,7 @@ 2. Run the ONNXRuntime server with the image created in step 1 ``` - docker run -v {localModelAbsoluteFolder}:{dockerModelAbsoluteFolder} -e MODEL_ABSOLUTE_PATH={dockerModelAbsolutePath} -p {your_local_port}:8001 {imageName} + docker run -v {localModelAbsoluteFolder}:{dockerModelAbsoluteFolder} -p {your_local_port}:8001 {imageName} --model_path {dockerModelAbsolutePath} ``` 3. 
Send HTTP requests to the container running ONNX Runtime Server diff --git a/dockerfiles/install_common_deps.sh b/dockerfiles/scripts/install_common_deps.sh similarity index 81% rename from dockerfiles/install_common_deps.sh rename to dockerfiles/scripts/install_common_deps.sh index dab394cb33fe7..173734332b761 100644 --- a/dockerfiles/install_common_deps.sh +++ b/dockerfiles/scripts/install_common_deps.sh @@ -13,13 +13,13 @@ apt-get install -y --no-install-recommends \ # Dependencies: conda wget --quiet https://repo.anaconda.com/miniconda/Miniconda3-4.5.11-Linux-x86_64.sh -O ~/miniconda.sh --no-check-certificate && /bin/bash ~/miniconda.sh -b -p /opt/miniconda rm ~/miniconda.sh -/opt/miniconda/bin/conda clean -tipsy -find / -type d -name __pycache__ -prune -exec rm -rf {}; +/opt/miniconda/bin/conda clean -ya -conda install -y python=3.6 numpy -conda clean -aqy +/opt/miniconda/bin/conda install -y numpy +/opt/miniconda/bin/conda clean -aqy rm -rf /opt/miniconda/pkgs # Dependencies: cmake sudo wget --quiet https://github.com/Kitware/CMake/releases/download/v3.14.3/cmake-3.14.3-Linux-x86_64.tar.gz tar zxf cmake-3.14.3-Linux-x86_64.tar.gz +rm -rf cmake-3.14.3-Linux-x86_64.tar.gz \ No newline at end of file diff --git a/docs/ONNX_Runtime_Perf_Tuning.md b/docs/ONNX_Runtime_Perf_Tuning.md index b0ac2e0b2c039..b76b82534d34f 100644 --- a/docs/ONNX_Runtime_Perf_Tuning.md +++ b/docs/ONNX_Runtime_Perf_Tuning.md @@ -72,7 +72,7 @@ sess_options.set_graph_optimization_level(2) ``` * sess_options.session_thread_pool_size=2 controls how many thread do you want to use to run your model * sess_options.enable_sequential_execution=True controls whether you want to run operators in your graph sequentially or in parallel. Usually when your model has many branches, set this option to false will give you better performance. -* sess_options.set_graph_optimization_level(2). There are three levels, 0 means disable optimization, 1 means enable optimizations before graph partition, 2 means enable all optimization. +* sess_options.set_graph_optimization_level(2). Default is 1. Please see [onnxruntime_c_api.h](../include/onnxruntime/core/session/onnxruntime_c_api.h#L241) (enum GraphOptimizationLevel) for the full list of all optimization levels. ### MKL_DNN/nGraph/MKL_ML Execution Provider MKL_DNN, MKL_ML and nGraph all depends on openmp for parallization. For those execution providers, we need to use openmp enviroment variable to tune the performance. diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md new file mode 100644 index 0000000000000..2cad94aae0481 --- /dev/null +++ b/docs/OperatorKernels.md @@ -0,0 +1,470 @@ +## Supported Operators Data Types +*This file is automatically generated from the + [def files](/onnxruntime/core/providers/cpu/cpu_execution_provider.cc) via [this script](/tools/python/gen_opkernel_doc.py). 
+ Do not modify directly and instead edit operator definitions.* + + + +## Operators implemented by CPUExecutionProvider + +| Op Name | Parameters | OpSet Version | Types Supported | +|---------|------------|---------------|-----------------| +**Operator Domain:** *ai.onnx.ml* +|Abs|(*in* X:**T**, *out* Y:**T**)|6+|**T** = tensor(int32), tensor(int16), tensor(uint8), unknown, tensor(uint32), tensor(uint16), tensor(float), tensor(uint64), tensor(int64), tensor(double)| +|Acos|(*in* input:**T**, *out* output:**T**)|7+|**T** = tensor(float)| +|Acosh|(*in* input:**T**, *out* output:**T**)|9+|**T** = tensor(float)| +|Add|(*in* A:**T**, *in* B:**T**, *out* C:**T**)|7+|**T** = tensor(int32), tensor(float), tensor(int64), tensor(double)| +|Affine|(*in* X:**T**, *out* Y:**T**)|1+|**T** = tensor(float)| +|And|(*in* A:**T**, *in* B:**T**, *out* C:**T1**)|7+|**T** = tensor(bool)| +| | ||**T1** = tensor(bool)| +|ArgMax|(*in* data:**T**, *out* reduced:**tensor(int64)**)|1+|**T** = tensor(int32), tensor(float)| +|ArgMin|(*in* data:**T**, *out* reduced:**tensor(int64)**)|1+|**T** = tensor(int32), tensor(float)| +|ArrayFeatureExtractor|(*in* X:**T**, *in* Y:**tensor(int64)**, *out* Z:**T**)|1+|**T** = tensor(string), tensor(int32), tensor(float), tensor(int64), tensor(double)| +|Asin|(*in* input:**T**, *out* output:**T**)|7+|**T** = tensor(float)| +|Asinh|(*in* input:**T**, *out* output:**T**)|9+|**T** = tensor(float)| +|Atan|(*in* input:**T**, *out* output:**T**)|7+|**T** = tensor(float)| +|Atanh|(*in* input:**T**, *out* output:**T**)|9+|**T** = tensor(float)| +|AveragePool|(*in* X:**T**, *out* Y:**T**)|10+|**T** = tensor(float)| +| | |[7, 9]|**T** = tensor(float)| +|BatchNormalization|(*in* X:**T**, *in* scale:**T**, *in* B:**T**, *in* mean:**T**, *in* var:**T**, *out* Y:**T**, *out* mean:**T**, *out* var:**T**, *out* saved_mean:**T**, *out* saved_var:**T**)|[7, 9]|**B** = tensor(float)| +| | ||**X** = tensor(float)| +| | ||**mean** = tensor(float)| +| | ||**scale** = tensor(float)| +| | ||**var** = tensor(float)| +|Binarizer|(*in* X:**T**, *out* Y:**T**)|1+|**T** = tensor(float)| +|Cast|(*in* input:**T1**, *out* output:**T2**)|9+|**T1** = tensor(string)| +| | ||**T2** = tensor(int32), tensor(bool), tensor(int16), tensor(uint8), unknown, tensor(uint32), tensor(uint16), tensor(string), tensor(float), tensor(uint64), tensor(MLFloat16), tensor(int64), tensor(double)| +| | |[6, 9]|**T1** = tensor(int32), tensor(bool), tensor(int16), tensor(uint8), unknown, tensor(uint32), tensor(uint16), tensor(float), tensor(uint64), tensor(MLFloat16), tensor(int64), tensor(double)| +| | ||**T2** = tensor(int32), tensor(bool), tensor(int16), tensor(uint8), unknown, tensor(uint32), tensor(uint16), tensor(string), tensor(float), tensor(uint64), tensor(MLFloat16), tensor(int64), tensor(double)| +|CastMap|(*in* X:**T1**, *out* Y:**T2**)|1+|**T1** = unknown| +| | ||**T2** = tensor(string), tensor(float), tensor(int64)| +|CategoryMapper|(*in* X:**T1**, *out* Y:**T2**)|1+|**T1** = tensor(string), tensor(int64)| +| | ||**T2** = tensor(string), tensor(int64)| +|Ceil|(*in* X:**T**, *out* Y:**T**)|6+|**T** = tensor(float)| +|Clip|(*in* input:**T**, *out* output:**T**)|6+|**T** = tensor(float)| +|Compress|(*in* input:**T**, *in* condition:**T1**, *out* output:**T**)|9+|**T** = tensor(int32), tensor(bool), tensor(int16), tensor(bfloat16), tensor(uint8), unknown, tensor(uint32), tensor(uint16), tensor(string), tensor(float), tensor(uint64), tensor(MLFloat16), tensor(int64), tensor(double)| +| | ||**T1** = tensor(bool)| +|Concat|(*in* 
inputs:**T**, *out* concat_result:**T**)|4+|**T** = tensor(int32), tensor(bool), tensor(int16), tensor(bfloat16), tensor(uint8), unknown, tensor(uint32), tensor(uint16), tensor(string), tensor(float), tensor(uint64), tensor(MLFloat16), tensor(int64), tensor(double)| +|ConstantOfShape|(*in* input:**T1**, *out* output:**T2**)|9+|**T1** = tensor(int64)| +| | ||**T2** = tensor(int32), tensor(bool), tensor(int16), tensor(uint8), unknown, tensor(uint32), tensor(uint16), tensor(float), tensor(uint64), tensor(MLFloat16), tensor(int64), tensor(double)| +|Conv|(*in* X:**T**, *in* W:**T**, *in* B:**T**, *out* Y:**T**)|1+|**T** = tensor(float)| +|ConvInteger|(*in* x:**T1**, *in* w:**T2**, *in* x_zero_point:**T1**, *in* w_zero_point:**T2**, *out* y:**T3**)|10+|**T1** = tensor(uint8)| +| | ||**T2** = tensor(uint8)| +| | ||**T3** = tensor(int32)| +|ConvTranspose|(*in* X:**T**, *in* W:**T**, *in* B:**T**, *out* Y:**T**)|1+|**T** = tensor(float)| +|Cos|(*in* input:**T**, *out* output:**T**)|7+|**T** = tensor(float)| +|Cosh|(*in* input:**T**, *out* output:**T**)|9+|**T** = tensor(float)| +|Crop|(*in* input:**T**, *out* output:**T**)|1+|**T** = tensor(float)| +|DepthToSpace|(*in* input:**T**, *out* output:**T**)|[1, 4]|**T** = tensor(float)| +|DequantizeLinear|(*in* x:**T**, *in* x_scale:**tensor(float)**, *in* x_zero_point:**T**, *out* y:**tensor(float)**)|10+|**x** = tensor(uint8), unknown| +| | ||**x_scale** = tensor(float)| +| | ||**x_zero_point** = tensor(uint8), unknown| +| | ||**y** = tensor(float)| +|DictVectorizer|(*in* X:**T1**, *out* Y:**T2**)|1+|**T1** = unknown| +| | ||**T2** = tensor(string), tensor(float), tensor(int64), tensor(double)| +|Div|(*in* A:**T**, *in* B:**T**, *out* C:**T**)|7+|**T** = tensor(int32), tensor(float), tensor(int64), tensor(double)| +|Dropout|(*in* data:**T**, *out* output:**T**, *out* mask:**T**) or (*in* data:**T**, *out* output:**T**, *out* mask:**T1**)|10+|**T** = tensor(float), tensor(MLFloat16), tensor(double)| +| | ||**T1** = tensor(bool)| +| | |[7, 9]|**T** = tensor(float), tensor(MLFloat16), tensor(double)| +| | ||**T1** = tensor(bool)| +|DynamicSlice|(*in* data:**T**, *in* starts:**Tind**, *in* ends:**Tind**, *in* axes:**Tind**, *out* output:**T**)|1+|**T** = tensor(int32), tensor(bool), tensor(int16), tensor(uint8), unknown, tensor(uint32), tensor(uint16), tensor(string), tensor(float), tensor(uint64), tensor(MLFloat16), tensor(int64), tensor(double)| +| | ||**Tind** = tensor(int32), tensor(int64)| +|Elu|(*in* X:**T**, *out* Y:**T**)|6+|**T** = tensor(float)| +|Equal|(*in* A:**T**, *in* B:**T**, *out* C:**T1**)|11+|**T** = tensor(float)| +| | ||**T1** = tensor(bool)| +| | |7+|**T** = tensor(int32), tensor(bool), tensor(int64)| +| | ||**T1** = tensor(bool)| +|Erf|(*in* input:**T**, *out* output:**T**)|9+|**T** = tensor(float)| +|Exp|(*in* input:**T**, *out* output:**T**)|6+|**T** = tensor(float), tensor(double)| +|Expand|(*in* input:**T**, *in* shape:**tensor(int64)**, *out* output:**T**)|8+|**T** = tensor(int32), tensor(bool), tensor(int16), tensor(uint8), unknown, tensor(uint32), tensor(uint16), tensor(float), tensor(uint64), tensor(MLFloat16), tensor(int64), tensor(double)| +|EyeLike|(*in* input:**T1**, *out* output:**T2**)|9+|**T1** = tensor(uint64), tensor(int32), tensor(float), tensor(int64), tensor(double)| +| | ||**T2** = tensor(uint64), tensor(int32), tensor(float), tensor(int64), tensor(double)| +|FeatureVectorizer|(*in* X:**T1**, *out* Y:**tensor(float)**)|1+|**T1** = tensor(int32), tensor(float), tensor(int64), tensor(double)| +|Flatten|(*in* 
input:**T**, *out* output:**T**)|9+|**T** = tensor(int32), tensor(bool), tensor(int16), tensor(bfloat16), tensor(uint8), unknown, tensor(uint32), tensor(uint16), tensor(string), tensor(float), tensor(uint64), tensor(MLFloat16), tensor(int64), tensor(double)| +| | |[1, 8]|**T** = tensor(int32), tensor(bool), tensor(int16), tensor(bfloat16), tensor(uint8), unknown, tensor(uint32), tensor(uint16), tensor(string), tensor(float), tensor(uint64), tensor(MLFloat16), tensor(int64), tensor(double)| +|Floor|(*in* X:**T**, *out* Y:**T**)|6+|**T** = tensor(float)| +|GRU|(*in* X:**T**, *in* W:**T**, *in* R:**T**, *in* B:**T**, *in* sequence_lens:**T1**, *in* initial_h:**T**, *out* Y:**T**, *out* Y_h:**T**)|7+|**T** = tensor(float), tensor(double)| +| | ||**T1** = tensor(int32)| +|Gather|(*in* data:**T**, *in* indices:**Tind**, *out* output:**T**)|1+|**T** = tensor(int32), tensor(bool), tensor(int16), tensor(bfloat16), tensor(uint8), unknown, tensor(uint32), tensor(uint16), tensor(string), tensor(float), tensor(uint64), tensor(MLFloat16), tensor(int64), tensor(double)| +| | ||**Tind** = tensor(int32), tensor(int64)| +|Gemm|(*in* A:**T**, *in* B:**T**, *in* C:**T**, *out* Y:**T**)|[7, 9]|**T** = tensor(float)| +|GlobalAveragePool|(*in* X:**T**, *out* Y:**T**)|1+|**T** = tensor(float)| +|GlobalLpPool|(*in* X:**T**, *out* Y:**T**)|2+|**T** = tensor(float)| +|GlobalMaxPool|(*in* X:**T**, *out* Y:**T**)|1+|**T** = tensor(float)| +|Greater|(*in* A:**T**, *in* B:**T**, *out* C:**T1**)|9+|**T** = tensor(int32), tensor(int64)| +| | ||**T1** = tensor(bool)| +| | |[7, 9]|**T** = tensor(float)| +| | ||**T1** = tensor(bool)| +|HardSigmoid|(*in* X:**T**, *out* Y:**T**)|6+|**T** = tensor(float)| +|Hardmax|(*in* input:**T**, *out* output:**T**)|1+|**T** = tensor(float)| +|Identity|(*in* input:**T**, *out* output:**T**)|1+|**T** = tensor(int32), tensor(bool), tensor(int16), tensor(bfloat16), tensor(uint8), unknown, tensor(uint32), tensor(uint16), tensor(string), tensor(float), tensor(uint64), tensor(MLFloat16), tensor(int64), tensor(double)| +|If|(*in* cond:**B**, *out* outputs:**V**)|1+|**B** = tensor(bool)| +| | ||**V** = tensor(int32), tensor(bool), tensor(int16), tensor(bfloat16), tensor(uint8), unknown, tensor(uint32), tensor(uint16), tensor(string), tensor(float), tensor(uint64), tensor(MLFloat16), tensor(int64), tensor(double)| +|ImageScaler|(*in* input:**T**, *out* output:**T**)|1+|**T** = tensor(float)| +|Imputer|(*in* X:**T**, *out* Y:**T**)|1+|**T** = tensor(float), tensor(int64)| +|InstanceNormalization|(*in* input:**T**, *in* scale:**T**, *in* B:**T**, *out* output:**T**)|6+|**T** = tensor(float)| +|IsInf|(*in* X:**T1**, *out* Y:**T2**)|10+|**T1** = tensor(float), tensor(double)| +| | ||**T2** = tensor(bool)| +|IsNaN|(*in* X:**T1**, *out* Y:**T2**)|9+|**T1** = tensor(float), tensor(MLFloat16)| +| | ||**T2** = tensor(bool)| +|LRN|(*in* X:**T**, *out* Y:**T**)|1+|**T** = tensor(float)| +|LSTM|(*in* X:**T**, *in* W:**T**, *in* R:**T**, *in* B:**T**, *in* sequence_lens:**T1**, *in* initial_h:**T**, *in* initial_c:**T**, *in* P:**T**, *out* Y:**T**, *out* Y_h:**T**, *out* Y_c:**T**)|7+|**T** = tensor(float), tensor(double)| +| | ||**T1** = tensor(int32)| +|LabelEncoder|(*in* X:**T1**, *out* Y:**T2**)|2+|**T1** = tensor(string), tensor(float), tensor(int64)| +| | ||**T2** = tensor(string), tensor(float), tensor(int64)| +| | |[1, 1]|**T1** = tensor(string), tensor(int64)| +| | ||**T2** = tensor(string), tensor(int64)| +|LeakyRelu|(*in* X:**T**, *out* Y:**T**)|6+|**T** = tensor(float)| +|Less|(*in* A:**T**, *in* 
B:**T**, *out* C:**T1**)|9+|**T** = tensor(int32), tensor(int64)| +| | ||**T1** = tensor(bool)| +| | |[7, 9]|**T** = tensor(float)| +| | ||**T1** = tensor(bool)| +|LinearClassifier|(*in* X:**T1**, *out* Y:**T2**, *out* Z:**tensor(float)**)|1+|**T1** = tensor(int32), tensor(float), tensor(int64), tensor(double)| +| | ||**T2** = tensor(string), tensor(int64)| +|LinearRegressor|(*in* X:**T**, *out* Y:**tensor(float)**)|1+|**T** = tensor(float)| +|Log|(*in* input:**T**, *out* output:**T**)|6+|**T** = tensor(float)| +|LogSoftmax|(*in* input:**T**, *out* output:**T**)|1+|**T** = tensor(float)| +|Loop|(*in* M:**I**, *in* cond:**B**, *in* v_initial:**V**, *out* v_final_and_scan_outputs:**V**)|1+|**B** = tensor(bool)| +| | ||**I** = tensor(int64)| +| | ||**V** = tensor(int32), tensor(bool), tensor(int16), tensor(bfloat16), tensor(uint8), unknown, tensor(uint32), tensor(uint16), tensor(string), tensor(float), tensor(uint64), tensor(MLFloat16), tensor(int64), tensor(double)| +|LpNormalization|(*in* input:**T**, *out* output:**T**)|1+|**T** = tensor(float)| +|LpPool|(*in* X:**T**, *out* Y:**T**)|2+|**T** = tensor(float)| +|MatMul|(*in* A:**T**, *in* B:**T**, *out* Y:**T**)|[1, 9]|**T** = tensor(float), tensor(double)| +| | |[9, 9]|**T** = tensor(uint64), tensor(int32), tensor(int64), tensor(uint32)| +|MatMulInteger|(*in* A:**T1**, *in* B:**T2**, *in* a_zero_point:**T1**, *in* b_zero_point:**T2**, *out* Y:**T3**)|10+|**T1** = tensor(uint8)| +| | ||**T2** = tensor(uint8)| +| | ||**T3** = tensor(int32)| +|Max|(*in* data_0:**T**, *out* max:**T**)|8+|**T** = tensor(float), tensor(double)| +| | |[6, 7]|**T** = tensor(float)| +|MaxPool|(*in* X:**T**, *out* Y:**T**) or (*in* X:**T**, *out* Y:**T**, *out* Indices:**I**)|10+|**I** = tensor(int64)| +| | ||**T** = tensor(float)| +| | |[1, 7]|**T** = tensor(float)| +| | |[8, 9]|**I** = tensor(int64)| +| | ||**T** = tensor(float)| +|MaxRoiPool|(*in* X:**T**, *in* rois:**T**, *out* Y:**T**)|1+|**T** = tensor(float)| +|MaxUnpool|(*in* X:**T1**, *in* I:**T2**, *in* output_shape:**T2**, *out* output:**T1**)|9+|**T1** = tensor(float)| +| | ||**T2** = tensor(int64)| +|Mean|(*in* data_0:**T**, *out* mean:**T**)|8+|**T** = tensor(float)| +| | |[6, 7]|**T** = tensor(float)| +|MeanVarianceNormalization|(*in* X:**T**, *out* Y:**T**) or (*in* input:**T**, *out* output:**T**)|9+|**T** = tensor(float)| +| | |[1, 8]|**T** = tensor(float)| +|Min|(*in* data_0:**T**, *out* min:**T**)|8+|**T** = tensor(float)| +| | |[6, 7]|**T** = tensor(float)| +|Mod|(*in* A:**T**, *in* B:**T**, *out* C:**T**)|10+|**T** = tensor(int32), tensor(int16), tensor(uint8), unknown, tensor(uint32), tensor(uint16), tensor(float), tensor(uint64), tensor(MLFloat16), tensor(int64), tensor(double)| +|Mul|(*in* A:**T**, *in* B:**T**, *out* C:**T**)|7+|**T** = tensor(int32), tensor(float), tensor(int64), tensor(double)| +|Multinomial|(*in* input:**T1**, *out* output:**T2**)|7+|**T1** = tensor(float)| +| | ||**T2** = tensor(int32), tensor(int64)| +|Neg|(*in* X:**T**, *out* Y:**T**)|6+|**T** = tensor(int32), tensor(float), unknown| +|NonZero|(*in* X:**T**, *out* Y:**tensor(int64)**)|9+|**T** = tensor(int32), tensor(float), tensor(bool), tensor(int64)| +|Normalizer|(*in* X:**T**, *out* Y:**tensor(float)**)|1+|**T** = tensor(int32), tensor(float), tensor(int64), tensor(double)| +|Not|(*in* X:**T**, *out* Y:**T**)|1+|**T** = tensor(bool)| +| | ||**T1** = tensor(bool)| +|OneHot|(*in* indices:**T1**, *in* depth:**T2**, *in* values:**T3**, *out* output:**T3**)|9+|**T1** = tensor(int32), tensor(float), tensor(int64)| +| | 
||**T2** = tensor(int32), tensor(float), tensor(int64)| +| | ||**T3** = tensor(string), tensor(int32), tensor(float), tensor(int64)| +|OneHotEncoder|(*in* X:**T**, *out* Y:**tensor(float)**)|1+|**T** = tensor(string), tensor(float), tensor(int64), tensor(double)| +|Or|(*in* A:**T**, *in* B:**T**, *out* C:**T1**)|7+|**T** = tensor(bool)| +| | ||**T1** = tensor(bool)| +|PRelu|(*in* X:**T**, *in* slope:**T**, *out* Y:**T**)|[7, 9]|**T** = tensor(float)| +|Pad|(*in* data:**T**, *out* output:**T**)|2+|**T** = tensor(float)| +|ParametricSoftplus|(*in* X:**T**, *out* Y:**T**)|1+|**T** = tensor(float)| +|Pow|(*in* X:**T**, *in* Y:**T**, *out* Z:**T**)|7+|**T** = tensor(float), tensor(double)| +|QLinearConv|(*in* x:**T1**, *in* x_scale:**tensor(float)**, *in* x_zero_point:**T1**, *in* w:**T2**, *in* w_scale:**tensor(float)**, *in* w_zero_point:**T2**, *in* y_scale:**tensor(float)**, *in* y_zero_point:**T3**, *in* B:**T4**, *out* y:**T3**)|10+|**T1** = tensor(uint8)| +| | ||**T2** = tensor(uint8)| +| | ||**T3** = tensor(uint8)| +| | ||**T4** = tensor(int32)| +|QLinearMatMul|(*in* a:**T1**, *in* a_scale:**tensor(float)**, *in* a_zero_point:**T1**, *in* b:**T2**, *in* b_scale:**tensor(float)**, *in* b_zero_point:**T2**, *in* y_scale:**tensor(float)**, *in* y_zero_point:**T3**, *out* y:**T3**)|10+|**T1** = tensor(uint8)| +| | ||**T2** = tensor(uint8)| +| | ||**T3** = tensor(uint8)| +|QuantizeLinear|(*in* x:**T1**, *in* y_scale:**tensor(float)**, *in* y_zero_point:**T2**, *out* y:**T2**)|10+|**x** = tensor(float)| +| | ||**y** = tensor(uint8), unknown| +| | ||**y_zero_point** = tensor(uint8), unknown| +|RNN|(*in* X:**T**, *in* W:**T**, *in* R:**T**, *in* B:**T**, *in* sequence_lens:**T1**, *in* initial_h:**T**, *out* Y:**T**, *out* Y_h:**T**)|7+|**T** = tensor(float)| +| | ||**T1** = tensor(int32)| +|RandomNormal|(*out* output:**T**)|1+|**T** = tensor(float), tensor(double)| +|RandomNormalLike|(*in* input:**T1**, *out* output:**T2**)|1+|**T1** = tensor(int32), tensor(bool), tensor(int16), tensor(bfloat16), tensor(uint8), unknown, tensor(uint32), tensor(uint16), tensor(string), tensor(float), tensor(uint64), tensor(MLFloat16), tensor(int64), tensor(double)| +| | ||**T2** = tensor(float), tensor(double)| +|RandomUniform|(*out* output:**T**)|1+|**T** = tensor(float), tensor(double)| +|RandomUniformLike|(*in* input:**T1**, *out* output:**T2**)|1+|**T1** = tensor(int32), tensor(bool), tensor(int16), tensor(bfloat16), tensor(uint8), unknown, tensor(uint32), tensor(uint16), tensor(string), tensor(float), tensor(uint64), tensor(MLFloat16), tensor(int64), tensor(double)| +| | ||**T2** = tensor(float), tensor(double)| +|Reciprocal|(*in* X:**T**, *out* Y:**T**)|6+|**T** = tensor(float)| +|ReduceL1|(*in* data:**T**, *out* reduced:**T**)|1+|**T** = tensor(int32), tensor(float)| +|ReduceL2|(*in* data:**T**, *out* reduced:**T**)|1+|**T** = tensor(int32), tensor(float)| +|ReduceLogSum|(*in* data:**T**, *out* reduced:**T**)|1+|**T** = tensor(int32), tensor(float)| +|ReduceLogSumExp|(*in* data:**T**, *out* reduced:**T**)|1+|**T** = tensor(int32), tensor(float)| +|ReduceMax|(*in* data:**T**, *out* reduced:**T**)|1+|**T** = tensor(int32), tensor(float)| +|ReduceMean|(*in* data:**T**, *out* reduced:**T**)|1+|**T** = tensor(int32), tensor(float)| +|ReduceMin|(*in* data:**T**, *out* reduced:**T**)|1+|**T** = tensor(int32), tensor(float)| +|ReduceProd|(*in* data:**T**, *out* reduced:**T**)|1+|**T** = tensor(int32), tensor(float)| +|ReduceSum|(*in* data:**T**, *out* reduced:**T**)|1+|**T** = tensor(int32), tensor(float), 
tensor(double)| +|ReduceSumSquare|(*in* data:**T**, *out* reduced:**T**)|1+|**T** = tensor(int32), tensor(float), tensor(double)| +|Relu|(*in* X:**T**, *out* Y:**T**)|6+|**T** = tensor(float)| +|Reshape|(*in* data:**T**, *in* shape:**tensor(int64)**, *out* reshaped:**T**) or (*in* data:**T**, *out* reshaped:**T**)|5+|**T** = tensor(int32), tensor(bool), tensor(int16), tensor(bfloat16), tensor(uint8), unknown, tensor(uint32), tensor(uint16), tensor(string), tensor(float), tensor(uint64), tensor(MLFloat16), tensor(int64), tensor(double)| +| | ||**shape** = tensor(int64)| +|Reshape_1||[1, 4]|**T** = tensor(int32), tensor(bool), tensor(int16), tensor(bfloat16), tensor(uint8), unknown, tensor(uint32), tensor(uint16), tensor(string), tensor(float), tensor(uint64), tensor(MLFloat16), tensor(int64), tensor(double)| +|Resize|(*in* X:**T**, *in* scales:**tensor(float)**, *out* Y:**T**)|10+|**T** = tensor(int32), tensor(float), tensor(uint8)| +|ReverseSequence|(*in* input:**T**, *in* sequence_lens:**tensor(int64)**, *out* Y:**T**)|10+|**T** = tensor(int32), tensor(bool), tensor(int16), tensor(bfloat16), tensor(uint8), unknown, tensor(uint32), tensor(uint16), tensor(string), tensor(float), tensor(uint64), tensor(MLFloat16), tensor(int64), tensor(double)| +|RoiAlign|(*in* X:**T1**, *in* rois:**T1**, *in* batch_indices:**T2**, *out* Y:**T1**)|10+|**T** = tensor(float), tensor(double)| +| | ||**T2** = tensor(int64)| +|SVMClassifier|(*in* X:**T1**, *out* Y:**T2**, *out* Z:**tensor(float)**)|1+|**T1** = tensor(int32), tensor(float), tensor(int64), tensor(double)| +| | ||**T2** = tensor(string), tensor(int64)| +|SVMRegressor|(*in* X:**T**, *out* Y:**tensor(float)**)|1+|**T** = tensor(float)| +|Scale|(*in* input:**T**, *out* output:**T**)|1+|**T** = tensor(float)| +|ScaledTanh|(*in* input:**T**, *out* output:**T**)|1+|**T** = tensor(float)| +|Scaler|(*in* X:**T**, *out* Y:**tensor(float)**)|1+|**T** = tensor(int32), tensor(float), tensor(int64), tensor(double)| +|Scan|(*in* sequence_lens:**I**, *in* initial_state_and_scan_inputs:**V**, *out* final_state_and_scan_outputs:**V**) or (*in* initial_state_and_scan_inputs:**V**, *out* final_state_and_scan_outputs:**V**)|9+|**I** = tensor(int64)| +| | ||**V** = tensor(int32), tensor(bool), tensor(int16), tensor(bfloat16), tensor(uint8), unknown, tensor(uint32), tensor(uint16), tensor(string), tensor(float), tensor(uint64), tensor(MLFloat16), tensor(int64), tensor(double)| +| | |[8, 8]|**I** = tensor(int64)| +| | ||**V** = tensor(int32), tensor(bool), tensor(int16), tensor(bfloat16), tensor(uint8), unknown, tensor(uint32), tensor(uint16), tensor(string), tensor(float), tensor(uint64), tensor(MLFloat16), tensor(int64), tensor(double)| +|Scatter|(*in* data:**T**, *in* indices:**Tind**, *in* updates:**T**, *out* output:**T**)|9+|**T** = tensor(int32), tensor(bool), tensor(int16), tensor(bfloat16), tensor(uint8), unknown, tensor(uint32), tensor(uint16), tensor(string), tensor(float), tensor(uint64), tensor(MLFloat16), tensor(int64), tensor(double)| +| | ||**Tind** = tensor(int32), tensor(int64)| +|Selu|(*in* X:**T**, *out* Y:**T**)|6+|**T** = tensor(float)| +|Shape|(*in* data:**T**, *out* shape:**T1**)|1+|**T** = tensor(int32), tensor(bool), tensor(int16), tensor(bfloat16), tensor(uint8), unknown, tensor(uint32), tensor(uint16), tensor(float), tensor(uint64), tensor(MLFloat16), tensor(int64), tensor(double)| +| | ||**T1** = tensor(int64)| +|Shrink|(*in* input:**T**, *out* output:**T**)|9+|**T** = tensor(int32), tensor(int16), tensor(bfloat16), tensor(uint8), unknown, 
tensor(uint32), tensor(uint16), tensor(float), tensor(uint64), tensor(MLFloat16), tensor(int64), tensor(double)| +|Sigmoid|(*in* X:**T**, *out* Y:**T**)|6+|**T** = tensor(float)| +|Sign|(*in* input:**T**, *out* output:**T**)|9+|**T** = tensor(int32), tensor(int16), tensor(bfloat16), tensor(uint8), unknown, tensor(uint32), tensor(uint16), tensor(float), tensor(uint64), tensor(MLFloat16), tensor(int64), tensor(double)| +|Sin|(*in* input:**T**, *out* output:**T**)|7+|**T** = tensor(float), tensor(double)| +|Sinh|(*in* input:**T**, *out* output:**T**)|9+|**T** = tensor(float)| +|Size|(*in* data:**T**, *out* size:**T1**)|1+|**T** = tensor(int32), tensor(bool), tensor(int16), tensor(uint8), unknown, tensor(uint32), tensor(uint16), tensor(string), tensor(float), tensor(uint64), tensor(int64), tensor(double)| +| | ||**T1** = tensor(int64)| +|Slice|(*in* data:**T**, *out* output:**T**) or (*in* data:**T**, *in* starts:**Tind**, *in* ends:**Tind**, *in* axes:**Tind**, *in* steps:**Tind**, *out* output:**T**)|10+|**T** = tensor(int32), tensor(bool), tensor(int16), tensor(uint8), unknown, tensor(uint32), tensor(uint16), tensor(string), tensor(float), tensor(uint64), tensor(MLFloat16), tensor(int64), tensor(double)| +| | ||**Tind** = tensor(int32), tensor(int64)| +| | |[1, 9]|**T** = tensor(int32), tensor(bool), tensor(int16), tensor(uint8), unknown, tensor(uint32), tensor(uint16), tensor(string), tensor(float), tensor(uint64), tensor(MLFloat16), tensor(int64), tensor(double)| +|Softmax|(*in* input:**T**, *out* output:**T**)|1+|**T** = tensor(float)| +|Softplus|(*in* X:**T**, *out* Y:**T**)|1+|**T** = tensor(float)| +|Softsign|(*in* input:**T**, *out* output:**T**)|1+|**T** = tensor(float)| +|SpaceToDepth|(*in* input:**T**, *out* output:**T**)|1+|**T** = tensor(float)| +|Split|(*in* input:**T**, *out* outputs:**T**) or (*in* input:**T**, *in* split:**T**, *out* outputs...:**T**)|2+|**T** = tensor(string), tensor(int32), tensor(float)| +|Sqrt|(*in* X:**T**, *out* Y:**T**)|6+|**T** = tensor(float), tensor(double)| +|Squeeze|(*in* data:**T**, *out* squeezed:**T**)|1+|**T** = tensor(int32), tensor(bool), tensor(int16), tensor(bfloat16), tensor(uint8), unknown, tensor(uint32), tensor(uint16), tensor(string), tensor(float), tensor(uint64), tensor(MLFloat16), tensor(int64), tensor(double)| +|StringNormalizer|(*in* X:**tensor(string)**, *out* Y:**tensor(string)**)|10+|**T** = tensor(string)| +|Sub|(*in* A:**T**, *in* B:**T**, *out* C:**T**)|7+|**T** = tensor(int32), tensor(float), tensor(int64), tensor(double)| +|Sum|(*in* data_0:**T**, *out* sum:**T**)|8+|**T** = tensor(float)| +| | |[6, 7]|**T** = tensor(float)| +|Tan|(*in* input:**T**, *out* output:**T**)|7+|**T** = tensor(float)| +|Tanh|(*in* input:**T**, *out* output:**T**)|6+|**T** = tensor(float)| +|TfIdfVectorizer|(*in* X:**T**, *out* Y:**T1**)|9+|**T** = tensor(string), tensor(int32), tensor(int64)| +| | ||**T1** = tensor(float)| +|ThresholdedRelu|(*in* X:**T**, *out* Y:**T**)|1+|**T** = tensor(float)| +| | |10+|**T** = tensor(float)| +|Tile|(*in* input:**T**, *in* tiles:**T**, *in* axis:**T**, *out* output:**T**) or (*in* input:**T**, *in* repeats:**T1**, *out* output:**T**)|6+|**T** = tensor(int32), tensor(bool), tensor(int16), tensor(uint8), unknown, tensor(uint32), tensor(uint16), tensor(float), tensor(uint64), tensor(int64), tensor(double)| +| | ||**T1** = tensor(int64)| +|TopK|(*in* X:**T**, *in* K:**tensor(int64)**, *out* Values:**T**, *out* Indices:**I**) or (*in* X:**T**, *out* Values:**T**, *out* Indices:**I**)|10+|**I** = tensor(int64)| +| 
| ||**T** = tensor(float)| +| | |[1, 9]|**I** = tensor(int64)| +| | ||**T** = tensor(float)| +|Transpose|(*in* data:**T**, *out* transposed:**T**)|1+|**T** = tensor(int32), tensor(bool), tensor(int16), tensor(bfloat16), tensor(uint8), unknown, tensor(uint32), tensor(uint16), tensor(string), tensor(float), tensor(uint64), tensor(MLFloat16), tensor(int64), tensor(double)| +|TreeEnsembleClassifier|(*in* X:**T1**, *out* Y:**T2**, *out* Z:**tensor(float)**)|1+|**T1** = tensor(int32), tensor(float), tensor(int64), tensor(double)| +| | ||**T2** = tensor(string), tensor(int64)| +|TreeEnsembleRegressor|(*in* X:**T**, *out* Y:**tensor(float)**)|1+|**T** = tensor(float)| +|Unsqueeze|(*in* data:**T**, *out* expanded:**T**)|1+|**T** = tensor(int32), tensor(bool), tensor(int16), tensor(bfloat16), tensor(uint8), unknown, tensor(uint32), tensor(uint16), tensor(string), tensor(float), tensor(uint64), tensor(MLFloat16), tensor(int64), tensor(double)| +|Upsample|(*in* X:**T**, *out* Y:**T**) or (*in* X:**T**, *in* scales:**tensor(float)**, *out* Y:**T**)|[7, 9]|**T** = tensor(int32), tensor(float), tensor(uint8)| +|Where|(*in* condition:**B**, *in* X:**T**, *in* Y:**T**, *out* output:**T**)|9+|**T** = tensor(string), tensor(int32), tensor(float)| +|Xor|(*in* A:**T**, *in* B:**T**, *out* C:**T1**)|7+|**T** = tensor(bool)| +| | ||**T1** = tensor(bool)| +|ZipMap|(*in* X:**tensor(float)**, *out* Z:**T**)|1+|**T** = unknown| +| | +| | +**Operator Domain:** *com.microsoft* +|AttnLSTM|(*in* X:**T**, *in* W:**T**, *in* R:**T**, *in* B:**T**, *in* sequence_lens:**T1**, *in* initial_h:**T**, *in* initial_c:**T**, *in* P:**T**, *in* QW:**T**, *in* MW:**T**, *in* V:**T**, *in* M:**T**, *in* memory_seq_lens:**T1**, *in* AW:**T**, *out* Y:**T**, *out* Y_h:**T**, *out* Y_c:**T**)|1+|**T** = tensor(float), tensor(double)| +| | ||**T1** = tensor(int32)| +|ConvTransposeWithDynamicPads|(*in* X:**T**, *in* W:**T**, *in* Pads:**tensor(int64)**, *in* B:**T**, *out* Y:**T**)|1+|**T** = tensor(float)| +|CropAndResize|(*in* X:**T1**, *in* rois:**T1**, *in* batch_indices:**T2**, *in* crop_size:**T2**, *out* Y:**T1**)|1+|**T** = tensor(float)| +| | ||**T2** = tensor(int32)| +|ExpandDims|(*in* X:**T**, *in* axis:**tensor(int32)**, *out* Y:**T**)|1+|**T** = tensor(int32), tensor(bool), tensor(int16), tensor(bfloat16), tensor(uint8), unknown, tensor(uint32), tensor(uint16), tensor(string), tensor(float), tensor(uint64), tensor(MLFloat16), tensor(int64), tensor(double)| +| | ||**axis** = tensor(int32)| +|FusedConv|(*in* X:**T**, *in* W:**T**, *in* B:**T**, *out* Y:**T**)|1+|**T** = tensor(float)| +|FusedGemm|(*in* A:**T**, *in* B:**T**, *in* C:**T**, *out* Y:**T**)|1+|**T** = tensor(float)| +|GatherND|(*in* data:**T**, *in* indices:**Tind**, *out* output:**T**)|1+|**T** = tensor(int32), tensor(bool), tensor(int16), tensor(bfloat16), tensor(uint8), unknown, tensor(uint32), tensor(uint16), tensor(string), tensor(float), tensor(uint64), tensor(MLFloat16), tensor(int64), tensor(double)| +| | ||**Tind** = tensor(int32), tensor(int64)| +|MaxpoolWithMask|(*in* X:**T**, *in* M:**tensor(int32)**, *out* Y:**T**)|1+|**X** = tensor(float)| +|MurmurHash3|(*in* X:**T1**, *out* Y:**T2**)|1+|**T1** = tensor(string), tensor(int32), tensor(uint32)| +| | ||**T2** = tensor(int32), tensor(uint32)| +|Pad|(*in* data:**T**, *in* pads:**tensor(int64)**, *in* value:**T**, *out* output:**T**)|1+|**T** = tensor(float)| +|Range|(*in* start:**T**, *in* limit:**T**, *in* delta:**T**, *out* Y:**T**)|1+|**T** = tensor(int32), tensor(float), tensor(int64), tensor(int16), 
tensor(double)| +|SampleOp|(*in* X:**T**, *out* Y:**T**)|1+|**T** = tensor(float)| +|Tokenizer|(*in* X:**T**, *out* Y:**T**)|1+|**T** = tensor(string)| +|Unique|(*in* x:**T**, *out* y:**T**, *out* idx:**tensor(int64)**, *out* counts:**tensor(int64)**)|1+|**T** = tensor(float)| +|WordConvEmbedding|(*in* Sequence:**T**, *in* W:**T1**, *in* B:**T1**, *in* C:**T1**, *out* Y:**T1**)|1+|**T** = tensor(int32)| +| | ||**T1** = tensor(float)| +| | +| | +**Operator Domain:** *com.microsoft.nchwc* +|AveragePool|(*in* X:**T**, *out* Y:**T**)|1+|**T** = tensor(float)| +|Conv|(*in* X:**T**, *in* W:**T**, *in* B:**T**, *in* Sum:**T**, *out* Y:**T**)|1+|**T** = tensor(float)| +|GlobalAveragePool|(*in* X:**T**, *out* Y:**T**)|1+|**T** = tensor(float)| +|GlobalMaxPool|(*in* X:**T**, *out* Y:**T**)|1+|**T** = tensor(float)| +|MaxPool|(*in* X:**T**, *out* Y:**T**)|1+|**T** = tensor(float)| +|ReorderInput|(*in* X:**T**, *out* Y:**T**)|1+|**T** = tensor(float)| +|ReorderOutput|(*in* X:**T**, *out* Y:**T**)|1+|**T** = tensor(float)| +| | +| | + + +## Operators implemented by CUDAExecutionProvider + +| Op Name | Parameters | OpSet Version | Types Supported | +|---------|------------|---------------|-----------------| +**Operator Domain:** *ai.onnx.ml* +|Abs|(*in* X:**T**, *out* Y:**T**)|6+|**T** = tensor(int32), tensor(int16), tensor(uint8), unknown, tensor(uint32), tensor(uint16), tensor(float), tensor(uint64), tensor(MLFloat16), tensor(int64), tensor(double)| +|Add|(*in* A:**T**, *in* B:**T**, *out* C:**T**)|7+|**T** = tensor(int32), tensor(uint32), tensor(float), tensor(uint64), tensor(MLFloat16), tensor(int64), tensor(double)| +|Affine|(*in* X:**T**, *out* Y:**T**)|1+|**T** = tensor(float), tensor(MLFloat16), tensor(double)| +|And|(*in* A:**T**, *in* B:**T**, *out* C:**T1**)|7+|**T** = tensor(bool)| +| | ||**T1** = tensor(bool)| +|ArgMax|(*in* data:**T**, *out* reduced:**tensor(int64)**)|1+|**T** = tensor(float), tensor(MLFloat16), tensor(double)| +|ArgMin|(*in* data:**T**, *out* reduced:**tensor(int64)**)|1+|**T** = tensor(float), tensor(MLFloat16), tensor(double)| +|AveragePool|(*in* X:**T**, *out* Y:**T**)|10+|**T** = tensor(float), tensor(MLFloat16), tensor(double)| +| | |[7, 9]|**I** = tensor(int64)| +| | ||**T** = tensor(float), tensor(MLFloat16), tensor(double)| +|BatchNormalization|(*in* X:**T**, *in* scale:**T**, *in* B:**T**, *in* mean:**T**, *in* var:**T**, *out* Y:**T**, *out* mean:**T**, *out* var:**T**, *out* saved_mean:**T**, *out* saved_var:**T**)|9+|**B** = tensor(float), tensor(MLFloat16), tensor(double)| +| | ||**X** = tensor(float), tensor(MLFloat16), tensor(double)| +| | ||**mean** = tensor(float), tensor(MLFloat16), tensor(double)| +| | ||**scale** = tensor(float), tensor(MLFloat16), tensor(double)| +| | ||**var** = tensor(float), tensor(MLFloat16), tensor(double)| +| | |[7, 8]|**B** = tensor(float), tensor(MLFloat16), tensor(double)| +| | ||**X** = tensor(float), tensor(MLFloat16), tensor(double)| +| | ||**mean** = tensor(float), tensor(MLFloat16), tensor(double)| +| | ||**scale** = tensor(float), tensor(MLFloat16), tensor(double)| +| | ||**var** = tensor(float), tensor(MLFloat16), tensor(double)| +|Cast|(*in* input:**T1**, *out* output:**T2**)|9+|**T1** = tensor(int32), tensor(bool), tensor(int16), tensor(uint8), unknown, tensor(uint32), tensor(uint16), tensor(float), tensor(uint64), tensor(MLFloat16), tensor(int64), tensor(double)| +| | ||**T2** = tensor(int32), tensor(bool), tensor(int16), tensor(uint8), unknown, tensor(uint32), tensor(uint16), tensor(float), tensor(uint64), 
tensor(MLFloat16), tensor(int64), tensor(double)| +| | |[6, 8]|**T1** = tensor(int32), tensor(bool), tensor(int16), tensor(uint8), unknown, tensor(uint32), tensor(uint16), tensor(float), tensor(uint64), tensor(MLFloat16), tensor(int64), tensor(double)| +| | ||**T2** = tensor(int32), tensor(bool), tensor(int16), tensor(uint8), unknown, tensor(uint32), tensor(uint16), tensor(float), tensor(uint64), tensor(MLFloat16), tensor(int64), tensor(double)| +|Ceil|(*in* X:**T**, *out* Y:**T**)|6+|**T** = tensor(float), tensor(MLFloat16), tensor(double)| +|Compress|(*in* input:**T**, *in* condition:**T1**, *out* output:**T**)|9+|**T** = tensor(int32), tensor(bool), tensor(int16), tensor(bfloat16), tensor(uint8), unknown, tensor(uint32), tensor(uint16), tensor(float), tensor(uint64), tensor(MLFloat16), tensor(int64), tensor(double)| +| | ||**T1** = tensor(bool)| +|Concat|(*in* inputs:**T**, *out* concat_result:**T**)|4+|**T** = tensor(int32), tensor(bool), tensor(int16), tensor(bfloat16), tensor(uint8), unknown, tensor(uint32), tensor(uint16), tensor(float), tensor(uint64), tensor(MLFloat16), tensor(int64), tensor(double)| +|ConstantOfShape|(*in* input:**T1**, *out* output:**T2**)|9+|**T1** = tensor(int64)| +| | ||**T2** = tensor(int32), tensor(bool), tensor(int16), tensor(bfloat16), tensor(uint8), unknown, tensor(uint32), tensor(uint16), tensor(float), tensor(uint64), tensor(MLFloat16), tensor(int64), tensor(double)| +|Conv|(*in* X:**T**, *in* W:**T**, *in* B:**T**, *out* Y:**T**)|1+|**T** = tensor(float), tensor(MLFloat16), tensor(double)| +|ConvTranspose|(*in* X:**T**, *in* W:**T**, *in* B:**T**, *out* Y:**T**)|1+|**T** = tensor(float), tensor(MLFloat16), tensor(double)| +|Crop|(*in* input:**T**, *out* output:**T**)|1+|**T** = tensor(float), tensor(MLFloat16), tensor(double)| +|Div|(*in* A:**T**, *in* B:**T**, *out* C:**T**)|7+|**T** = tensor(int32), tensor(uint32), tensor(float), tensor(uint64), tensor(MLFloat16), tensor(int64), tensor(double)| +|Dropout|(*in* data:**T**, *out* output:**T**, *out* mask:**T**) or (*in* data:**T**, *out* output:**T**, *out* mask:**T1**)|10+|**T** = tensor(float), tensor(MLFloat16), tensor(double)| +| | ||**T1** = tensor(bool)| +| | |[7, 9]|**T** = tensor(float), tensor(MLFloat16), tensor(double)| +|DynamicSlice|(*in* data:**T**, *in* starts:**Tind**, *in* ends:**Tind**, *in* axes:**Tind**, *out* output:**T**)|1+|**T** = tensor(int32), tensor(bool), tensor(int16), tensor(bfloat16), tensor(uint8), unknown, tensor(uint32), tensor(uint16), tensor(float), tensor(uint64), tensor(MLFloat16), tensor(int64), tensor(double)| +| | ||**Tind** = tensor(int32), tensor(int64)| +|Elu|(*in* X:**T**, *out* Y:**T**)|6+|**T** = tensor(float), tensor(MLFloat16), tensor(double)| +|Equal|(*in* A:**T**, *in* B:**T**, *out* C:**T1**)|7+|**T** = tensor(int32), tensor(bool), tensor(int64)| +|Erf|(*in* input:**T**, *out* output:**T**)|9+|**T** = tensor(float), tensor(MLFloat16), tensor(double)| +|Exp|(*in* input:**T**, *out* output:**T**)|6+|**T** = tensor(float), tensor(MLFloat16), tensor(double)| +|Expand|(*in* input:**T**, *in* shape:**tensor(int64)**, *out* output:**T**)|8+|**T** = tensor(int32), tensor(bool), tensor(int16), tensor(bfloat16), tensor(uint8), unknown, tensor(uint32), tensor(uint16), tensor(float), tensor(uint64), tensor(MLFloat16), tensor(int64), tensor(double)| +|Flatten|(*in* input:**T**, *out* output:**T**)|9+|**T** = tensor(int32), tensor(bool), tensor(int16), tensor(bfloat16), tensor(uint8), unknown, tensor(uint32), tensor(uint16), tensor(float), tensor(uint64), 
tensor(MLFloat16), tensor(int64), tensor(double)| +| | |[1, 8]|**T** = tensor(int32), tensor(bool), tensor(int16), tensor(bfloat16), tensor(uint8), unknown, tensor(uint32), tensor(uint16), tensor(float), tensor(uint64), tensor(MLFloat16), tensor(int64), tensor(double)| +|Floor|(*in* X:**T**, *out* Y:**T**)|6+|**T** = tensor(float), tensor(MLFloat16), tensor(double)| +|GRU|(*in* X:**T**, *in* W:**T**, *in* R:**T**, *in* B:**T**, *in* sequence_lens:**T1**, *in* initial_h:**T**, *out* Y:**T**, *out* Y_h:**T**)|7+|**T** = tensor(float), tensor(MLFloat16), tensor(double)| +| | ||**T1** = tensor(int32)| +|Gather|(*in* data:**T**, *in* indices:**Tind**, *out* output:**T**)|1+|**T** = tensor(int32), tensor(bool), tensor(int16), tensor(bfloat16), tensor(uint8), unknown, tensor(uint32), tensor(uint16), tensor(float), tensor(uint64), tensor(MLFloat16), tensor(int64), tensor(double)| +| | ||**Tind** = tensor(int32), tensor(int64)| +|Gemm|(*in* A:**T**, *in* B:**T**, *in* C:**T**, *out* Y:**T**)|9+|**T** = tensor(float), tensor(MLFloat16), tensor(double)| +| | |[7, 8]|**T** = tensor(float), tensor(MLFloat16), tensor(double)| +|GlobalAveragePool|(*in* X:**T**, *out* Y:**T**)|1+|**T** = tensor(float), tensor(MLFloat16), tensor(double)| +|GlobalMaxPool|(*in* X:**T**, *out* Y:**T**)|1+|**T** = tensor(float), tensor(MLFloat16), tensor(double)| +|Greater|(*in* A:**T**, *in* B:**T**, *out* C:**T1**)|9+|**T** = tensor(int32), tensor(uint32), tensor(float), tensor(uint64), tensor(MLFloat16), tensor(int64), tensor(double)| +| | ||**T1** = tensor(bool)| +| | |[7, 8]|**T** = tensor(float), tensor(MLFloat16), tensor(double)| +|HardSigmoid|(*in* X:**T**, *out* Y:**T**)|6+|**T** = tensor(float), tensor(MLFloat16), tensor(double)| +|Identity|(*in* input:**T**, *out* output:**T**)|1+|**T** = tensor(int32), tensor(bool), tensor(int16), tensor(bfloat16), tensor(uint8), unknown, tensor(uint32), tensor(uint16), tensor(float), tensor(uint64), tensor(MLFloat16), tensor(int64), tensor(double)| +|ImageScaler|(*in* input:**T**, *out* output:**T**)|1+|**T** = tensor(float), tensor(MLFloat16), tensor(double)| +|InstanceNormalization|(*in* input:**T**, *in* scale:**T**, *in* B:**T**, *out* output:**T**)|6+|**T** = tensor(float), tensor(MLFloat16), tensor(double)| +|LRN|(*in* X:**T**, *out* Y:**T**)|1+|**T** = tensor(float), tensor(MLFloat16), tensor(double)| +|LSTM|(*in* X:**T**, *in* W:**T**, *in* R:**T**, *in* B:**T**, *in* sequence_lens:**T1**, *in* initial_h:**T**, *in* initial_c:**T**, *in* P:**T**, *out* Y:**T**, *out* Y_h:**T**, *out* Y_c:**T**)|7+|**T** = tensor(float), tensor(MLFloat16), tensor(double)| +| | ||**T1** = tensor(int32)| +|LeakyRelu|(*in* X:**T**, *out* Y:**T**)|6+|**T** = tensor(float), tensor(MLFloat16), tensor(double)| +|Log|(*in* input:**T**, *out* output:**T**)|6+|**T** = tensor(float), tensor(MLFloat16), tensor(double)| +|MatMul|(*in* A:**T**, *in* B:**T**, *out* Y:**T**)|9+|**T** = tensor(float), tensor(MLFloat16), tensor(double)| +| | |[1, 8]|**T** = tensor(float), tensor(MLFloat16), tensor(double)| +|Max|(*in* data_0:**T**, *out* max:**T**)|8+|**T** = tensor(float), tensor(MLFloat16), tensor(double)| +| | |[6, 7]|**T** = tensor(float), tensor(MLFloat16), tensor(double)| +|MaxPool|(*in* X:**T**, *out* Y:**T**) or (*in* X:**T**, *out* Y:**T**, *out* Indices:**I**)|10+|**T** = tensor(float), tensor(MLFloat16), tensor(double)| +| | |[1, 7]|**I** = tensor(int64)| +| | ||**T** = tensor(float), tensor(MLFloat16), tensor(double)| +| | |[8, 9]|**I** = tensor(int64)| +| | ||**T** = tensor(float), 
tensor(MLFloat16), tensor(double)| +|MemcpyFromHost|(*in* X:**T**, *out* Y:**T**)|1+|**T** = tensor(int32), tensor(bool), tensor(int16), tensor(bfloat16), tensor(uint8), unknown, tensor(uint32), tensor(uint16), tensor(float), tensor(uint64), tensor(MLFloat16), tensor(int64), tensor(double)| +|MemcpyToHost|(*in* X:**T**, *out* Y:**T**)|1+|**T** = tensor(int32), tensor(bool), tensor(int16), tensor(bfloat16), tensor(uint8), unknown, tensor(uint32), tensor(uint16), tensor(float), tensor(uint64), tensor(MLFloat16), tensor(int64), tensor(double)| +|Min|(*in* data_0:**T**, *out* min:**T**)|8+|**T** = tensor(float), tensor(MLFloat16), tensor(double)| +| | |[6, 7]|**T** = tensor(float), tensor(MLFloat16), tensor(double)| +|Mul|(*in* A:**T**, *in* B:**T**, *out* C:**T**)|7+|**T** = tensor(int32), tensor(uint32), tensor(float), tensor(uint64), tensor(MLFloat16), tensor(int64), tensor(double)| +|Neg|(*in* X:**T**, *out* Y:**T**)|6+|**T** = tensor(int32), tensor(int16), unknown, tensor(float), tensor(MLFloat16), tensor(int64), tensor(double)| +|Or|(*in* A:**T**, *in* B:**T**, *out* C:**T1**)|7+|**T** = tensor(bool)| +| | ||**T1** = tensor(bool)| +|PRelu|(*in* X:**T**, *in* slope:**T**, *out* Y:**T**)|7+|**T** = tensor(float), tensor(MLFloat16), tensor(double)| +|Pad|(*in* data:**T**, *out* output:**T**)|2+|**T** = tensor(float), tensor(MLFloat16), tensor(double)| +|ParametricSoftplus|(*in* X:**T**, *out* Y:**T**)|1+|**T** = tensor(float), tensor(MLFloat16), tensor(double)| +|Pow|(*in* X:**T**, *in* Y:**T**, *out* Z:**T**)|7+|**T** = tensor(float), tensor(MLFloat16), tensor(double)| +|RNN|(*in* X:**T**, *in* W:**T**, *in* R:**T**, *in* B:**T**, *in* sequence_lens:**T1**, *in* initial_h:**T**, *out* Y:**T**, *out* Y_h:**T**)|7+|**T** = tensor(float), tensor(MLFloat16), tensor(double)| +| | ||**T1** = tensor(int32)| +|Reciprocal|(*in* X:**T**, *out* Y:**T**)|6+|**T** = tensor(float), tensor(MLFloat16), tensor(double)| +|ReduceL1|(*in* data:**T**, *out* reduced:**T**)|1+|**T** = tensor(float), tensor(MLFloat16), tensor(double)| +|ReduceL2|(*in* data:**T**, *out* reduced:**T**)|1+|**T** = tensor(float), tensor(MLFloat16), tensor(double)| +|ReduceLogSum|(*in* data:**T**, *out* reduced:**T**)|1+|**T** = tensor(float), tensor(MLFloat16), tensor(double)| +|ReduceLogSumExp|(*in* data:**T**, *out* reduced:**T**)|1+|**T** = tensor(float), tensor(MLFloat16), tensor(double)| +|ReduceMax|(*in* data:**T**, *out* reduced:**T**)|1+|**T** = tensor(float), tensor(MLFloat16), tensor(double)| +|ReduceMean|(*in* data:**T**, *out* reduced:**T**)|1+|**T** = tensor(float), tensor(MLFloat16), tensor(double)| +|ReduceMin|(*in* data:**T**, *out* reduced:**T**)|1+|**T** = tensor(float), tensor(MLFloat16), tensor(double)| +|ReduceProd|(*in* data:**T**, *out* reduced:**T**)|1+|**T** = tensor(float), tensor(MLFloat16), tensor(double)| +|ReduceSum|(*in* data:**T**, *out* reduced:**T**)|1+|**T** = tensor(float), tensor(MLFloat16), tensor(double)| +|ReduceSumSquare|(*in* data:**T**, *out* reduced:**T**)|1+|**T** = tensor(float), tensor(MLFloat16), tensor(double)| +|Relu|(*in* X:**T**, *out* Y:**T**)|6+|**T** = tensor(float), tensor(MLFloat16), tensor(double)| +|Reshape|(*in* data:**T**, *in* shape:**tensor(int64)**, *out* reshaped:**T**) or (*in* data:**T**, *out* reshaped:**T**)|5+|**T** = tensor(int32), tensor(bool), tensor(int16), tensor(bfloat16), tensor(uint8), unknown, tensor(uint32), tensor(uint16), tensor(float), tensor(uint64), tensor(MLFloat16), tensor(int64), tensor(double)| +| | ||**shape** = tensor(int64)| +|Reshape_1||[1, 
4]|**T** = tensor(int32), tensor(bool), tensor(int16), tensor(bfloat16), tensor(uint8), unknown, tensor(uint32), tensor(uint16), tensor(float), tensor(uint64), tensor(MLFloat16), tensor(int64), tensor(double)| +|Resize|(*in* X:**T**, *in* scales:**tensor(float)**, *out* Y:**T**)|10+|**T** = tensor(int32), tensor(float), tensor(MLFloat16), tensor(uint8), tensor(double)| +|ScaledTanh|(*in* input:**T**, *out* output:**T**)|1+|**T** = tensor(float), tensor(MLFloat16), tensor(double)| +|Selu|(*in* X:**T**, *out* Y:**T**)|6+|**T** = tensor(float), tensor(MLFloat16), tensor(double)| +|Shape|(*in* data:**T**, *out* shape:**T1**)|1+|**T** = tensor(int32), tensor(bool), tensor(int16), tensor(bfloat16), tensor(uint8), unknown, tensor(uint32), tensor(uint16), tensor(float), tensor(uint64), tensor(MLFloat16), tensor(int64), tensor(double)| +| | ||**T1** = tensor(int64)| +|Shrink|(*in* input:**T**, *out* output:**T**)|9+|**T** = tensor(int32), tensor(int16), tensor(uint8), unknown, tensor(uint32), tensor(uint16), tensor(float), tensor(uint64), tensor(MLFloat16), tensor(int64), tensor(double)| +|Sigmoid|(*in* X:**T**, *out* Y:**T**)|6+|**T** = tensor(float), tensor(MLFloat16), tensor(double)| +|Slice|(*in* data:**T**, *out* output:**T**) or (*in* data:**T**, *in* starts:**Tind**, *in* ends:**Tind**, *in* axes:**Tind**, *in* steps:**Tind**, *out* output:**T**)|10+|**T** = tensor(int32), tensor(bool), tensor(int16), tensor(bfloat16), tensor(uint8), unknown, tensor(uint32), tensor(uint16), tensor(float), tensor(uint64), tensor(MLFloat16), tensor(int64), tensor(double)| +| | ||**Tind** = tensor(int32), tensor(int64)| +| | |[1, 9]|**T** = tensor(int32), tensor(bool), tensor(int16), tensor(bfloat16), tensor(uint8), unknown, tensor(uint32), tensor(uint16), tensor(float), tensor(uint64), tensor(MLFloat16), tensor(int64), tensor(double)| +| | ||**Tind** = tensor(int32), tensor(int64)| +|Softmax|(*in* input:**T**, *out* output:**T**)|1+|**T** = tensor(float), tensor(MLFloat16), tensor(double)| +|Softplus|(*in* X:**T**, *out* Y:**T**)|1+|**T** = tensor(float), tensor(MLFloat16), tensor(double)| +|Softsign|(*in* input:**T**, *out* output:**T**)|1+|**T** = tensor(float), tensor(MLFloat16), tensor(double)| +|Split|(*in* input:**T**, *out* outputs:**T**) or (*in* input:**T**, *in* split:**T**, *out* outputs...:**T**)|2+|**T** = tensor(int32), tensor(bool), tensor(int16), tensor(bfloat16), tensor(uint8), unknown, tensor(uint32), tensor(uint16), tensor(float), tensor(uint64), tensor(MLFloat16), tensor(int64), tensor(double)| +|Sqrt|(*in* X:**T**, *out* Y:**T**)|6+|**T** = tensor(float), tensor(MLFloat16), tensor(double)| +|Squeeze|(*in* data:**T**, *out* squeezed:**T**)|1+|**T** = tensor(int32), tensor(bool), tensor(int16), tensor(bfloat16), tensor(uint8), unknown, tensor(uint32), tensor(uint16), tensor(float), tensor(uint64), tensor(MLFloat16), tensor(int64), tensor(double)| +|Sub|(*in* A:**T**, *in* B:**T**, *out* C:**T**)|7+|**T** = tensor(int32), tensor(uint32), tensor(float), tensor(uint64), tensor(MLFloat16), tensor(int64), tensor(double)| +|Sum|(*in* data_0:**T**, *out* sum:**T**)|8+|**T** = tensor(int32), tensor(uint32), tensor(float), tensor(uint64), tensor(MLFloat16), tensor(int64), tensor(double)| +| | |[6, 7]|**T** = tensor(int32), tensor(uint32), tensor(float), tensor(uint64), tensor(MLFloat16), tensor(int64), tensor(double)| +|Tanh|(*in* input:**T**, *out* output:**T**)|6+|**T** = tensor(float), tensor(MLFloat16), tensor(double)| +|ThresholdedRelu|(*in* X:**T**, *out* Y:**T**)|1+|**T** = tensor(float), 
tensor(MLFloat16), tensor(double)| +| | |10+|**T** = tensor(float), tensor(MLFloat16), tensor(double)| +|Tile|(*in* input:**T**, *in* tiles:**T**, *in* axis:**T**, *out* output:**T**) or (*in* input:**T**, *in* repeats:**T1**, *out* output:**T**)|6+|**T** = tensor(float), tensor(MLFloat16), tensor(double)| +| | ||**T1** = tensor(int64)| +|Transpose|(*in* data:**T**, *out* transposed:**T**)|1+|**T** = tensor(float), tensor(MLFloat16), tensor(double)| +|Unsqueeze|(*in* data:**T**, *out* expanded:**T**)|1+|**T** = tensor(int32), tensor(bool), tensor(int16), tensor(bfloat16), tensor(uint8), unknown, tensor(uint32), tensor(uint16), tensor(float), tensor(uint64), tensor(MLFloat16), tensor(int64), tensor(double)| +|Upsample|(*in* X:**T**, *out* Y:**T**) or (*in* X:**T**, *in* scales:**tensor(float)**, *out* Y:**T**)|[7, 9]|**T** = tensor(int32), tensor(float), tensor(MLFloat16), tensor(uint8), tensor(double)| +|Xor|(*in* A:**T**, *in* B:**T**, *out* C:**T1**)|7+|**T** = tensor(bool)| +| | ||**T1** = tensor(bool)| +| | +| | +**Operator Domain:** *com.microsoft* +|ConvTransposeWithDynamicPads|(*in* X:**T**, *in* W:**T**, *in* Pads:**tensor(int64)**, *in* B:**T**, *out* Y:**T**)|1+|**T** = tensor(float)| +| | +| | + + +## Operators implemented by MKLDNNExecutionProvider + +| Op Name | Parameters | OpSet Version | Types Supported | +|---------|------------|---------------|-----------------| +**Operator Domain:** *ai.onnx.ml* +|AveragePool|(*in* X:**T**, *out* Y:**T**)|[7, 8]|**T** = tensor(float)| +|BatchNormalization|(*in* X:**T**, *in* scale:**T**, *in* B:**T**, *in* mean:**T**, *in* var:**T**, *out* Y:**T**, *out* mean:**T**, *out* var:**T**, *out* saved_mean:**T**, *out* saved_var:**T**)|7+|**T** = tensor(float)| +|Conv|(*in* X:**T**, *in* W:**T**, *in* B:**T**, *out* Y:**T**)|1+|**T** = tensor(float)| +|Gemm|(*in* A:**T**, *in* B:**T**, *in* C:**T**, *out* Y:**T**)|7+|**T** = tensor(float)| +|GlobalAveragePool|(*in* X:**T**, *out* Y:**T**)|[1, 8]|**T** = tensor(float)| +|GlobalMaxPool|(*in* X:**T**, *out* Y:**T**)|[1, 8]|**T** = tensor(float)| +|LRN|(*in* X:**T**, *out* Y:**T**)|1+|**T** = tensor(float)| +|MaxPool|(*in* X:**T**, *out* Y:**T**) or (*in* X:**T**, *out* Y:**T**, *out* Indices:**I**)|[1, 7]|**T** = tensor(float)| +| | |[8, 8]|**T** = tensor(float)| +|Relu|(*in* X:**T**, *out* Y:**T**)|6+|**T** = tensor(float)| +|Sum|(*in* data_0:**T**, *out* sum:**T**)|6+|**T** = tensor(float)| +| | +| | diff --git a/docs/Versioning.md b/docs/Versioning.md index d646d777d8335..18a43eb712b05 100644 --- a/docs/Versioning.md +++ b/docs/Versioning.md @@ -45,7 +45,7 @@ A variety of tools can be used to create ONNX models. Unless otherwise noted, pl |Tool|Recommended Version|Supported ONNX version(s)| |---|---|---| -|[PyTorch](https://pytorch.org/)|[Latest stable](https://pytorch.org/get-started/locally/)|1.2-1.5*
*may require [ONNX version converter](https://github.com/onnx/onnx/blob/master/docs/VersionConverter.md) to convert to desired opset #*| +|[PyTorch](https://pytorch.org/)|[Latest stable](https://pytorch.org/get-started/locally/)|1.2-1.5| |[ONNXMLTools](https://pypi.org/project/onnxmltools/)
CoreML, LightGBM, XGBoost, LibSVM|[Latest stable](https://github.com/onnx/onnxmltools/releases)|1.2-1.5| |[ONNXMLTools](https://pypi.org/project/onnxmltools/)
SparkML|[Latest stable](https://github.com/onnx/onnxmltools/releases)|1.4-1.5| |[SKLearn-ONNX](https://pypi.org/project/skl2onnx/)|[Latest stable](https://github.com/onnx/sklearn-onnx/releases)|1.2-1.5| diff --git a/docs/execution_providers/Nuphar-ExecutionProvider.md b/docs/execution_providers/Nuphar-ExecutionProvider.md new file mode 100644 index 0000000000000..a7c859a818e75 --- /dev/null +++ b/docs/execution_providers/Nuphar-ExecutionProvider.md @@ -0,0 +1,142 @@ +## Nuphar Execution Provider (preview) + +NUPHAR stands for Neural-network Unified Preprocessing Heterogeneous ARchitecture. As an execution provider in the ONNX Runtime, it is built on top of [TVM](https://github.com/dmlc/tvm) and [LLVM](https://llvm.org) to accelerate ONNX models by compiling nodes in subgraphs into optimized functions via JIT. It also provides JIT caching to save compilation time at runtime. + +This execution provider release is currently in preview. With the Nuphar execution provider, the ONNX Runtime delivers better inferencing performance on the same hardware compared to generic X64 CPU acceleration, especially for quantized recurrent neural networks. Various products at Microsoft have seen up to a 5x improvement in performance with no loss of accuracy, by running quantized LSTMs via the Nuphar execution provider in the ONNX Runtime. + +### Build Nuphar execution provider +Developers can now tap into the power of Nuphar through ONNX Runtime to accelerate inferencing of ONNX models. In addition, the Nuphar execution provider comes with a common ONNX-to-TVM lowering [library](../../onnxruntime/core/codegen) that can be reused by other execution providers to leverage TVM. Instructions to build the Nuphar execution provider from source are available [here](../../BUILD.md#nuphar). + +### Using the Nuphar execution provider +#### C/C++ +The Nuphar execution provider needs to be registered with ONNX Runtime to enable it in the inference session. The C API details are [here](../C_API.md#c-api). + +#### Python +You can use the Nuphar execution provider via the Python wheel from the ONNX Runtime build. The Nuphar execution provider will be automatically prioritized over the default CPU execution providers, so there is no need to register the execution provider separately. Python API details are [here](../python/api_summary.rst#api-summary). + +### Using onnxruntime_perf_test/onnx_test_runner for performance and accuracy test +You can test your ONNX model's performance with [onnxruntime_perf_test](../../onnxruntime/test/perftest/README.md), or test accuracy with [onnx_test_runner](../../onnxruntime/test/onnx/README.txt). To run these tools with the Nuphar execution provider, pass `-e nuphar` in the command line options. + +### Model conversion/quantization +You may use the Python script [model_editor.py](../../onnxruntime/core/providers/nuphar/scripts/model_editor.py) to convert LSTM/GRU/RNN ops to Scan ops for a given model, and then use [model_quantizer.py](../../onnxruntime/core/providers/nuphar/scripts/model_quantizer.py) to quantize MatMul ops into MatMulInteger ops. + +We use dynamic per-row quantization for the inputs of LSTM MatMul, so the MatMul becomes three parts: quantization, MatMulInteger and dequantization. Weights for MatMulInteger are statically quantized per-column to int8. We have observed good speed-up and no loss of accuracy with this quantization scheme inside Scan for various LSTM models.
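+
+To illustrate the scheme described above, here is a minimal numpy sketch (a hypothetical example for exposition only, not code from this repository; it assumes symmetric int8 quantization): the input is quantized per row at run time, the weight is quantized per column ahead of time, the product is computed with an integer GEMM (the MatMulInteger part), and the result is dequantized with both scales.
+
+```
+# Illustrative sketch only -- not the actual Nuphar/ONNX Runtime implementation.
+import numpy as np
+
+def quantized_matmul(X, W):
+    # Static per-column weight quantization to int8 (done once, offline).
+    w_scale = np.abs(W).max(axis=0) / 127.0 + 1e-12   # one scale per output column
+    Wq = np.round(W / w_scale).astype(np.int8)
+
+    # Dynamic per-row quantization of the input at run time.
+    x_scale = np.abs(X).max(axis=1, keepdims=True) / 127.0 + 1e-12   # one scale per row
+    Xq = np.round(X / x_scale).astype(np.int8)
+
+    # Integer matrix multiply, accumulated in int32 (the MatMulInteger part).
+    Yq = Xq.astype(np.int32) @ Wq.astype(np.int32)
+
+    # Dequantize back to float with the row and column scales.
+    return Yq.astype(np.float32) * x_scale * w_scale
+
+X = np.random.randn(4, 8).astype(np.float32)
+W = np.random.randn(8, 16).astype(np.float32)
+print(np.max(np.abs(quantized_matmul(X, W) - X @ W)))  # quantization error is typically small
+```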
+ +To convert models with LSTM/GRU/RNN ops to Scan ops: +``` +python model_editor.py --input /path/to/input/model --output /path/to/output/model --mode to_scan +``` + +To quantize MatMul ops to MatMulInteger ops (use the option --only_for_scan to quantize only the MatMuls inside Scan): +``` +python model_quantizer.py --input /path/to/input/model --output /path/to/output/model --only_for_scan +``` + +As an experiment, you may test conversion and quantization on [the BiDAF model](https://github.com/onnx/models/tree/master/bidaf) from the ONNX model zoo. This model has 5 bidirectional LSTM ops and long sequence lengths. Our test shows that the quantized model has comparable accuracy (F1 76.24, EM 68.08) vs. the floating point model (F1 76.20, EM 68.11). + +The speed-up on this model is ~20% on an Intel Xeon E5-1620v4 (note that AVX2 is required for Nuphar int8 GEMV performance), comparing the CPU execution provider running the floating point model with LSTM ops against the Nuphar execution provider running quantized MatMulInteger inside Scan ops. Profiling shows that most of the cost is in the input projection outside of Scan ops, which uses MKL SGEMM. It's worth noting that MKL int8 GEMM is about the same speed as SGEMM in this model, so quantization of the SGEMMs outside of Scan won't help performance. We are looking at ways to speed up int8 GEMM for better performance on quantized models. + +### JIT caching +You may cache JIT binaries to reduce the model loading time spent in JIT, using [create_shared.cmd](../../onnxruntime/core/providers/nuphar/scripts/create_shared.cmd) on Windows with Visual Studio 2017, or [create_shared.sh](../../onnxruntime/core/providers/nuphar/scripts/create_shared.sh) on Linux with gcc. + +Windows +``` +REM You need Visual Studio 2017 to compile and link. Optionally, you can save the model checksum to the output dll with the FCIV tool from https://support.microsoft.com/en-us/help/841290 +set NUPHAR_CACHE_PATH=\path\to\jit\cache +REM Then run Nuphar inference from either onnx_test_runner or onnxruntime_perf_test, or whatever inference using C++ or Python +REM JIT object files would be saved to \path\to\jit\cache\ +create_shared.cmd \path\to\jit\cache\NUPHAR_CACHE_VERSION [optional_model_file_for_checksum] [optional_output_dll_name] +REM If a checksum is embedded in the dll, set NUPHAR_CACHE_MODEL_CHECKSUM to the FCIV output for the model being inferenced to pass checksum verification at runtime +REM Checksum verification failure will cause Nuphar to fall back to JIT instead of loading the binary from the cache +REM Run Nuphar inference again with the cached JIT dll +``` + +Linux +``` +# You need GCC of the same version Nuphar is built with to compile and link.
Optionally, you can save the model checksum to jit.so with md5sum +export NUPHAR_CACHE_PATH=/path/to/jit/cache +# Then run Nuphar inference from either onnx_test_runner or onnxruntime_perf_test, or whatever inference using C++ or Python +# JIT object files would be saved to /path/to/jit/cache/ +create_shared.sh -c /path/to/jit/cache/NUPHAR_CACHE_VERSION [-m optional_model_file_for_checksum] [-o optional_output_so_name] +# If a checksum is embedded in the .so, set NUPHAR_CACHE_MODEL_CHECKSUM to the md5sum output for the model being inferenced to pass checksum verification at runtime +# Checksum verification failure will cause Nuphar to fall back to JIT instead of loading the binary from the cache +# Run Nuphar inference again with the cached JIT .so +``` + +### Debugging +There are several [environment variables](../../onnxruntime/core/codegen/common/settings.h) to dump debug information during code generation, plus [some more environment variables](../../onnxruntime/core/providers/nuphar/common/nuphar_settings.h) to dump/control the Nuphar execution provider. You can set these environment variables prior to inference to dump debug info to the console. Some of the most useful ones are: +* CODEGEN_DUMP_LOWER + + Dumps the lowered function from TVM. + + Set it to "verbose" to dump all nodes, or to a node op_type to dump specific nodes. You may use "concise" to dump just the op_type of nodes. + +* CODEGEN_DUMP_MODULE + + Dumps the compiled binary. + + Set it to "ll" to dump LLVM bitcode, or "asm" to dump assembly. + +* CODEGEN_DUMP_SCHEDULE + + Dumps the schedule used in TVM nodes, like compute_root/compute_inline/compute_at. + + Set it to "verbose" to dump all nodes, or to a node op_type to dump specific nodes. You may use "concise" to dump just the op_type of nodes. + +* NUPHAR_DUMP_PARTITION + + Dumps the nodes in each partition. + + Set it to "1" to dump partitions. + +### Settings +When environment variables would conflict across multiple processes running Nuphar, the user can specify a settings string when creating the Nuphar execution provider. The string consists of comma-separated key:value pairs. Keys should be the lower-cased environment variable names shown above, separated from the corresponding values with a colon. For example, the equivalent of setting the NUPHAR_CACHE_PATH/NUPHAR_CACHE_MODEL_CHECKSUM environment variables would be "nuphar_cache_path:<path_to_cache>, nuphar_cache_model_checksum:<model_checksum>". + +* Using in C/C++ + +The settings string can be specified when creating the execution provider, to set the JIT cache path as well as the model checksum: + +``` +OrtStatus* status = OrtSessionOptionsAppendExecutionProvider_Nuphar(session_options, 1, "nuphar_cache_path:/path/to/cache, nuphar_cache_model_checksum:<model_checksum>"); +``` + +* Using in C# + +The settings string can be specified when creating session options: + +``` +SessionOptions.MakeSessionOptionWithNupharProvider("nuphar_cache_path:/path/to/cache, nuphar_cache_model_checksum:<model_checksum>") +``` + +* Using in Python + +The settings string should be passed in before the InferenceSession is created, as execution providers are not currently exposed in the Python API. Here's an example in Python that sets the cache path and model checksum: + +``` +nuphar_settings = 'nuphar_cache_path:{}, nuphar_cache_model_checksum:{}'.format(cache_dir, model_checksum) +onnxruntime.capi._pybind_state.set_nuphar_settings(nuphar_settings) +sess = onnxruntime.InferenceSession(model_path) +``` + +### Known issues +* ONNX shape inference dependency + + To save runtime JIT cost, Nuphar requires models to have shape inference information from ONNX after the model is loaded.
Some nodes in ONNX can generate dynamic output tensor shapes from input data values, e.g. ConstantOfShape, Tile, Slice in opset 10, Compress, etc. Those ops may block ONNX shape inference and make the part of the graph after such nodes not runnable in Nuphar. + + You may use the Python script [symbolic_shape_infer.py](../../onnxruntime/core/providers/nuphar/scripts/symbolic_shape_infer.py) to run symbolic shape inference on an ONNX model. This script adds output tensor shapes to the model's graph.value_info field, by doing symbolic dimension computation using sympy when there are Shape ops in the model. Running symbolic shape inference on an ONNX model also makes the graph more readable. Note that when using [model_editor.py](../../onnxruntime/core/providers/nuphar/scripts/model_editor.py) to convert models with LSTM/GRU/RNN to Scan, the resulting model may have incomplete shape inference. Running symbolic_shape_infer.py is needed to get the Scan ops in the model to run in Nuphar. Please note that quantization should be the last step, after verifying the accuracy and performance of the edited floating point model. + + In addition, you may manually add shapes to graph.value_info using [onnx.helper.make_tensor_value_info](https://github.com/onnx/onnx/blob/v1.5.0/onnx/helper.py#L290) with model-specific knowledge. For example, if a Hardmax output cast to bool is used as the Compress input condition, then the unknown dimension of the Compress output is actually 1. + +* Performance benchmark + + Nuphar's current speed-up for quantized RNNs is optimized for AVX2, single-threaded execution, and batch size 1. To help understand RNN performance in different configurations, please use the Python script [rnn_benchmark.py](../../onnxruntime/core/providers/nuphar/scripts/rnn_benchmark.py). For older X64 CPUs that do not support AVX2, quantized models may have worse performance than non-quantized ones. + +* Patches to TVM + + Some changes/bug fixes in TVM are needed for Nuphar to work properly. We are in the process of contributing them back to TVM, but for now patches are used in [our forked TVM](https://github.com/microsoft/onnxruntime-tvm). To build cleanly from scratch, please run the following commands before running build.bat or build.sh: +``` +git submodule sync +git submodule foreach --recursive git stash +git submodule foreach --recursive git clean -fd +git submodule update --init --recursive +``` \ No newline at end of file diff --git a/docs/execution_providers/TensorRT-ExecutionProvider.md b/docs/execution_providers/TensorRT-ExecutionProvider.md index 37c4c75ff58fa..a688d4c4cb813 100644 --- a/docs/execution_providers/TensorRT-ExecutionProvider.md +++ b/docs/execution_providers/TensorRT-ExecutionProvider.md @@ -1,11 +1,11 @@ -## TensortRT Execution Provider (preview) +## TensorRT Execution Provider -The TensorRT execution provider in the ONNX Runtime will make use of NVIDIA's [TensortRT](https://developer.nvidia.com/tensorrt) Deep Learning inferencing engine to accelerate ONNX model in their family of GPUs. Microsoft and NVIDIA worked closely to integrate the TensorRT execution provider with ONNX Runtime. +The TensorRT execution provider in the ONNX Runtime makes use of NVIDIA's [TensorRT](https://developer.nvidia.com/tensorrt) Deep Learning inferencing engine to accelerate ONNX models on their family of GPUs. Microsoft and NVIDIA worked closely to integrate the TensorRT execution provider with ONNX Runtime.
-This execution provider release is currently in preview but, we have validated support for all the ONNX Models in the model zoo. With the TensorRT execution provider, the ONNX Runtime delivers better inferencing performance on the same hardware compared to generic GPU acceleration. +With the TensorRT execution provider, the ONNX Runtime delivers better inferencing performance on the same hardware compared to generic GPU acceleration. ### Build TensorRT execution provider -Developers can now tap into the power of TensorRT through ONNX Runtime to accelerate inferencing of ONNX models. Instructions to build the TensorRT execution provider from source is available [here](https://github.com/Microsoft/onnxruntime/blob/master/BUILD.md#build). +Developers can now tap into the power of TensorRT through ONNX Runtime to accelerate inferencing of ONNX models. Instructions to build the TensorRT execution provider from source are available [here](https://github.com/Microsoft/onnxruntime/blob/master/BUILD.md#build). [Dockerfiles](https://github.com/microsoft/onnxruntime/tree/master/dockerfiles#tensorrt-version-preview) are available for convenience. ### Using the TensorRT execution provider #### C/C++ @@ -18,7 +18,23 @@ status = session_object.Load(model_file_name); The C API details are [here](https://github.com/Microsoft/onnxruntime/blob/master/docs/C_API.md#c-api). ### Python -When using the python wheel from the ONNX Runtime build with TensorRT execution provider, it will be automatically prioritized over the default GPU or CPU execution providers. There is no need to separately register the execution provider. Python APIs details are [here](https://github.com/Microsoft/onnxruntime/blob/master/docs/python/api_summary.rst#api-summary). +When using the Python wheel from the ONNX Runtime build with the TensorRT execution provider, it will be automatically prioritized over the default GPU or CPU execution providers. There is no need to separately register the execution provider. Python API details are [here](https://microsoft.github.io/onnxruntime/api_summary.html). + +### Performance Tuning +To test the performance of your ONNX Model with the TensorRT execution provider, use the flag `-e tensorrt` in [onnxruntime_perf_test](https://github.com/Microsoft/onnxruntime/tree/master/onnxruntime/test/perftest#onnxruntime-performance-test). + +### Sample +Please see [this Notebook](https://github.com/microsoft/onnxruntime/blob/master/docs/python/notebooks/onnx-inference-byoc-gpu-cpu-aks.ipynb) for an example of running a model on GPU using ONNX Runtime through Azure Machine Learning Services. ### Using onnxruntime_perf_test You can test the performance for your ONNX Model with the TensorRT execution provider. Use the flag `-e tensorrt` in [onnxruntime_perf_test](https://github.com/Microsoft/onnxruntime/tree/master/onnxruntime/test/perftest#onnxruntime-performance-test). + +### Configuring Engine Max Batch Size and Workspace Size +By default, the TensorRT execution provider builds an ICudaEngine with max batch size = 1 and max workspace size = 1 GB. +One can override these defaults by setting the environment variables ORT_TENSORRT_MAX_BATCH_SIZE and ORT_TENSORRT_MAX_WORKSPACE_SIZE. +e.g.
on Linux +#### override default batch size to 10 +export ORT_TENSORRT_MAX_BATCH_SIZE=10 +#### override default max workspace size to 2GB +export ORT_TENSORRT_MAX_WORKSPACE_SIZE=2147483648 + diff --git a/include/onnxruntime/core/common/callback.h b/include/onnxruntime/core/common/callback.h deleted file mode 100644 index b52288758d7e6..0000000000000 --- a/include/onnxruntime/core/common/callback.h +++ /dev/null @@ -1,17 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. -#pragma once -#include "core/session/onnxruntime_c_api.h" - -#ifdef __cplusplus -extern "C" { -#endif - -typedef struct OrtCallback { - void(ORT_API_CALL* f)(void* param) NO_EXCEPTION; - void* param; -} OrtDeleter; - -#ifdef __cplusplus -} -#endif \ No newline at end of file diff --git a/include/onnxruntime/core/framework/allocator.h b/include/onnxruntime/core/framework/allocator.h index 8a37553ea976b..d25645123896d 100644 --- a/include/onnxruntime/core/framework/allocator.h +++ b/include/onnxruntime/core/framework/allocator.h @@ -72,6 +72,10 @@ struct OrtDevice { DeviceId device_id; }; +inline bool operator==(const OrtDevice& left, const OrtDevice& other) { + return left.Id() == other.Id() && left.MemType() == other.MemType() && left.Type() == other.Type(); +} + struct OrtAllocatorInfo { // use string for name, so we could have customized allocator in execution provider. const char* name; @@ -128,6 +132,8 @@ namespace onnxruntime { constexpr const char* CPU = "Cpu"; constexpr const char* CUDA = "Cuda"; constexpr const char* CUDA_PINNED = "CudaPinned"; +constexpr const char* TRT = "Tensorrt"; +constexpr const char* TRT_PINNED = "TensorrtPinned"; // forward declaration class SessionState; diff --git a/include/onnxruntime/core/framework/data_types.h b/include/onnxruntime/core/framework/data_types.h index ea90436c7c3a9..83248a2721d23 100644 --- a/include/onnxruntime/core/framework/data_types.h +++ b/include/onnxruntime/core/framework/data_types.h @@ -205,6 +205,9 @@ class DataTypeImpl { static const std::vector& AllTensorTypes(); static const std::vector& AllFixedSizeTensorTypes(); static const std::vector& AllNumericTensorTypes(); + static const std::vector& AllIEEEFloatTensorTypes(); + static const std::vector& AllFixedSizeTensorExceptHalfTypes(); + static const std::vector& AllIEEEFloatTensorExceptHalfTypes(); }; std::ostream& operator<<(std::ostream& out, MLDataType data_type); diff --git a/include/onnxruntime/core/framework/kernel_def_builder.h b/include/onnxruntime/core/framework/kernel_def_builder.h index 3c093f45401fb..5f783348365bf 100644 --- a/include/onnxruntime/core/framework/kernel_def_builder.h +++ b/include/onnxruntime/core/framework/kernel_def_builder.h @@ -42,6 +42,12 @@ class KernelDef { *end = op_since_version_end_; } +#ifdef onnxruntime_PYBIND_EXPORT_OPSCHEMA + const std::pair SinceVersion() const { + return std::pair(op_since_version_start_, op_since_version_end_); + } +#endif + onnxruntime::ProviderType Provider() const { return provider_type_; } diff --git a/include/onnxruntime/core/framework/kernel_registry.h b/include/onnxruntime/core/framework/kernel_registry.h index 3a0d35e298f98..95d9b1d415b92 100644 --- a/include/onnxruntime/core/framework/kernel_registry.h +++ b/include/onnxruntime/core/framework/kernel_registry.h @@ -39,6 +39,14 @@ class KernelRegistry { bool IsEmpty() const { return kernel_creator_fn_map_.empty(); } +#ifdef onnxruntime_PYBIND_EXPORT_OPSCHEMA +// This is used by the opkernel doc generator to enlist all registered operators 
for a given provider's opkernel + const KernelCreateMap& GetKernelCreateMap() const + { + return kernel_creator_fn_map_; + } +#endif + private: // Check whether the types of inputs/outputs of the given node match the extra // type-constraints of the given kernel. This serves two purposes: first, to diff --git a/include/onnxruntime/core/framework/op_kernel.h b/include/onnxruntime/core/framework/op_kernel.h index e02027a328fcf..6e98dbc20588b 100644 --- a/include/onnxruntime/core/framework/op_kernel.h +++ b/include/onnxruntime/core/framework/op_kernel.h @@ -210,7 +210,7 @@ struct KernelCreateInfo { : kernel_def(std::move(definition)), kernel_create_func(create_func) {} - KernelCreateInfo(KernelCreateInfo&& other) + KernelCreateInfo(KernelCreateInfo&& other) noexcept : kernel_def(std::move(other.kernel_def)), kernel_create_func(std::move(other.kernel_create_func)) {} }; @@ -231,6 +231,11 @@ template KernelCreateInfo BuildKernelCreateInfo(); } // namespace contrib +namespace automl { +template +KernelCreateInfo BuildKernelCreateInfo(); +} // namespace automl + namespace contrib { namespace cuda { template diff --git a/include/onnxruntime/core/framework/tensor.h b/include/onnxruntime/core/framework/tensor.h index 35eb359c714a3..31a43c7d905cb 100644 --- a/include/onnxruntime/core/framework/tensor.h +++ b/include/onnxruntime/core/framework/tensor.h @@ -78,9 +78,9 @@ class Tensor final { //Move is allowed ORT_DISALLOW_COPY_AND_ASSIGNMENT(Tensor); - Tensor(Tensor&& other); + Tensor(Tensor&& other) noexcept; - Tensor& operator=(Tensor&& other); + Tensor& operator=(Tensor&& other) noexcept; /** Returns the data type. diff --git a/include/onnxruntime/core/framework/tensor_shape.h b/include/onnxruntime/core/framework/tensor_shape.h index acf39638fe0db..c280f61eb1518 100644 --- a/include/onnxruntime/core/framework/tensor_shape.h +++ b/include/onnxruntime/core/framework/tensor_shape.h @@ -34,12 +34,13 @@ class TensorShape : private std::vector { TensorShape(TensorShape&& /*other*/) = default; TensorShape& operator=(TensorShape&& /*other*/) = default; - TensorShape(const int64_t* dimension_sizes, size_t dimension_count); + TensorShape(const std::vector& dims) : std::vector(dims) {} + + TensorShape(std::vector&& dims) : std::vector(std::move(dims)) {} - TensorShape(const std::vector& dims); - TensorShape(std::vector&& dims); + TensorShape(const std::initializer_list& dims) : std::vector(dims) {} - TensorShape(const std::initializer_list& dims); + TensorShape(const int64_t* dimension_sizes, size_t dimension_count); TensorShape(const std::vector& dims, size_t start, size_t end); diff --git a/include/onnxruntime/core/graph/constants.h b/include/onnxruntime/core/graph/constants.h index 5872228f383d2..6a960e82a3074 100644 --- a/include/onnxruntime/core/graph/constants.h +++ b/include/onnxruntime/core/graph/constants.h @@ -19,6 +19,7 @@ constexpr const char* kOnnxDomainAlias = "ai.onnx"; constexpr const char* kMLDomain = "ai.onnx.ml"; constexpr const char* kMSDomain = "com.microsoft"; constexpr const char* kMSNchwcDomain = "com.microsoft.nchwc"; +constexpr const char* kMSAutoMLDomain = "com.microsoft.automl"; constexpr const char* kNGraphDomain = "com.intel.ai"; constexpr const char* kCpuExecutionProvider = "CPUExecutionProvider"; constexpr const char* kCudaExecutionProvider = "CUDAExecutionProvider"; diff --git a/include/onnxruntime/core/graph/graph.h b/include/onnxruntime/core/graph/graph.h index b626a7541713f..1901822011f74 100644 --- a/include/onnxruntime/core/graph/graph.h +++ 
b/include/onnxruntime/core/graph/graph.h @@ -25,7 +25,6 @@ namespace onnxruntime { class Graph; struct IndexedSubGraph; -class Node; class OpSignature; /** diff --git a/include/onnxruntime/core/graph/graph_viewer.h b/include/onnxruntime/core/graph/graph_viewer.h index 7e2a0364ed0db..8d6530719dc20 100644 --- a/include/onnxruntime/core/graph/graph_viewer.h +++ b/include/onnxruntime/core/graph/graph_viewer.h @@ -38,6 +38,9 @@ class GraphViewer { */ bool GetInitializedTensor(const std::string& tensor_name, const ONNX_NAMESPACE::TensorProto*& value) const; + /** Returns true if an initializer value can be overridden by a graph input with the same name. */ + bool CanOverrideInitializer() const noexcept; + /** Gets the Graph inputs, excluding initializers. @returns Collection of NodeArg pointers for the graph inputs, excluding inputs that have matching initializers. @@ -102,9 +105,15 @@ class GraphViewer { return graph_->DomainToVersionMap(); } - /** Check if this is a Subgraph */ + /** Checks if this is a Subgraph */ bool IsSubgraph() const; + /** + returns true if 'name' is an initializer, and is constant and cannot be overridden at runtime. + @param check_outer_scope If true and the 'graph_' is a subgraph, check parent graph/s for 'name' if not found in 'graph_'. + */ + bool IsConstantInitializer(const std::string& name, bool check_outer_scope) const; + private: ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(GraphViewer); diff --git a/include/onnxruntime/core/platform/threadpool.h b/include/onnxruntime/core/platform/threadpool.h index 66952591ce470..3337583612065 100644 --- a/include/onnxruntime/core/platform/threadpool.h +++ b/include/onnxruntime/core/platform/threadpool.h @@ -7,12 +7,27 @@ #include #include +#if defined(__GNUC__) +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wunused-parameter" +#else +#pragma warning(push) +#pragma warning(disable : 4267) +#endif +#include +#if defined(__GNUC__) +#pragma GCC diagnostic pop +#else +#pragma warning(pop) +#endif + namespace onnxruntime { namespace concurrency { /** * Generic class for instantiating thread pools. + * Don't put any object of this type into a global variable in a Win32 DLL. */ class ThreadPool { public: @@ -43,14 +58,10 @@ class ThreadPool { int CurrentThreadId() const; - /* - Ensure that the pool has terminated and cleaned up all threads cleanly. - */ - ~ThreadPool(); + Eigen::ThreadPool& GetHandler() { return impl_; } private: - class Impl; - std::unique_ptr impl_; + Eigen::ThreadPool impl_; }; } // namespace concurrency diff --git a/include/onnxruntime/core/providers/nuphar/nuphar_provider_factory.h b/include/onnxruntime/core/providers/nuphar/nuphar_provider_factory.h new file mode 100644 index 0000000000000..58c82a0e1f251 --- /dev/null +++ b/include/onnxruntime/core/providers/nuphar/nuphar_provider_factory.h @@ -0,0 +1,17 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +#pragma once +#include "core/session/onnxruntime_c_api.h" + +#ifdef __cplusplus +extern "C" { +#endif +/** + * \param device_id nuphar device id, starts from zero. + * \param target_str TVM target string. 
+ */ +ORT_API_STATUS(OrtSessionOptionsAppendExecutionProvider_Nuphar, _In_ OrtSessionOptions* options, int allow_unaligned_buffers, _In_ const char* settings_str); + +#ifdef __cplusplus +} +#endif diff --git a/include/onnxruntime/core/providers/tensorrt/tensorrt_provider_factory.h b/include/onnxruntime/core/providers/tensorrt/tensorrt_provider_factory.h index fb077fc5ff41d..f6f03f80465f4 100644 --- a/include/onnxruntime/core/providers/tensorrt/tensorrt_provider_factory.h +++ b/include/onnxruntime/core/providers/tensorrt/tensorrt_provider_factory.h @@ -7,7 +7,7 @@ extern "C" { #endif -ORT_API_STATUS(OrtSessionOptionsAppendExecutionProvider_Tensorrt, _In_ OrtSessionOptions* options); +ORT_API_STATUS(OrtSessionOptionsAppendExecutionProvider_Tensorrt, _In_ OrtSessionOptions* options, int device_id); #ifdef __cplusplus } diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index 6848fc31e453c..899d23181750a 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -23,6 +23,10 @@ extern "C" { #define _Inout_ #define _Inout_opt_ #define _Frees_ptr_opt_ +#define _Ret_maybenull_ +#define _Ret_notnull_ +#define _Check_return_ +#define _Success_(X) #define ORT_ALL_ARGS_NONNULL __attribute__((nonnull)) #else #include @@ -127,11 +131,11 @@ typedef enum OrtErrorCode { ORT_EXPORT RETURN_TYPE ORT_API_CALL NAME(__VA_ARGS__) NO_EXCEPTION #define ORT_API_STATUS(NAME, ...) \ - ORT_EXPORT OrtStatus* ORT_API_CALL NAME(__VA_ARGS__) NO_EXCEPTION ORT_MUST_USE_RESULT + ORT_EXPORT _Check_return_ _Success_(return == 0) _Ret_maybenull_ OrtStatus* ORT_API_CALL NAME(__VA_ARGS__) NO_EXCEPTION ORT_MUST_USE_RESULT // Used in *.cc files. Almost as same as ORT_API_STATUS, except without ORT_MUST_USE_RESULT #define ORT_API_STATUS_IMPL(NAME, ...) \ - ORT_EXPORT OrtStatus* ORT_API_CALL NAME(__VA_ARGS__) NO_EXCEPTION + ORT_EXPORT _Check_return_ _Success_(return == 0) _Ret_maybenull_ OrtStatus* ORT_API_CALL NAME(__VA_ARGS__) NO_EXCEPTION #define ORT_RUNTIME_CLASS(X) \ struct Ort##X; \ @@ -143,16 +147,13 @@ ORT_RUNTIME_CLASS(Env); ORT_RUNTIME_CLASS(Status); // nullptr for Status* indicates success ORT_RUNTIME_CLASS(Provider); ORT_RUNTIME_CLASS(AllocatorInfo); -ORT_RUNTIME_CLASS(Session); +ORT_RUNTIME_CLASS(Session); //Don't call OrtReleaseSession from Dllmain (because session owns a thread pool) ORT_RUNTIME_CLASS(Value); -ORT_RUNTIME_CLASS(ValueList); ORT_RUNTIME_CLASS(RunOptions); ORT_RUNTIME_CLASS(TypeInfo); ORT_RUNTIME_CLASS(TensorTypeAndShapeInfo); ORT_RUNTIME_CLASS(SessionOptions); -ORT_RUNTIME_CLASS(Callback); ORT_RUNTIME_CLASS(CustomOpDomain); -ORT_RUNTIME_CLASS(Allocator); // When passing in an allocator to any ORT function, be sure that the allocator object // is not destroyed until the last allocated object using it is freed. @@ -202,6 +203,9 @@ ORT_API_STATUS(OrtRun, _Inout_ OrtSession* sess, */ ORT_API_STATUS(OrtCreateSessionOptions, _Outptr_ OrtSessionOptions** options); +// Set filepath to save optimized model after graph level transformations. 
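// Hedged sketch (not part of the patch): attaching the TensorRT provider with
// the new device_id parameter. Device 0 is a placeholder; a null OrtStatus*
// indicates success, as noted on ORT_RUNTIME_CLASS(Status).
#include "core/session/onnxruntime_c_api.h"
#include "core/providers/tensorrt/tensorrt_provider_factory.h"

inline void AppendTensorRTProvider(OrtSessionOptions* session_options) {
  OrtStatus* status =
      OrtSessionOptionsAppendExecutionProvider_Tensorrt(session_options, /*device_id=*/0);
  if (status != nullptr) {
    // Inspect OrtGetErrorCode(status) if needed, then release the status object.
    OrtReleaseStatus(status);
  }
}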
+ORT_API_STATUS(OrtSetOptimizedModelFilePath, _Inout_ OrtSessionOptions* options, _In_ const ORTCHAR_T* optimized_model_filepath); + // create a copy of an existing OrtSessionOptions ORT_API_STATUS(OrtCloneSessionOptions, _In_ const OrtSessionOptions* in_options, _Outptr_ OrtSessionOptions** out_options); ORT_API_STATUS(OrtEnableSequentialExecution, _Inout_ OrtSessionOptions* options); @@ -230,15 +234,25 @@ ORT_API_STATUS(OrtSetSessionLogId, _Inout_ OrtSessionOptions* options, const cha // < applies to session load, initialization, etc ORT_API_STATUS(OrtSetSessionLogVerbosityLevel, _Inout_ OrtSessionOptions* options, int session_log_verbosity_level); +ORT_API_STATUS(OrtSetSessionLogSeverityLevel, _Inout_ OrtSessionOptions* options, int session_log_severity_level); // Set Graph optimization level. -// Available options are : 0, 1, 2. -// 0 -> Disable all optimizations -// 1 -> Enable basic optimizations -// 2 -> Enable all optimizations -ORT_API_STATUS(OrtSetSessionGraphOptimizationLevel, _Inout_ OrtSessionOptions* options, int graph_optimization_level); - -// How many threads in the session thread pool. +// TODO Add documentation about which optimizations are enabled for each value. +typedef enum GraphOptimizationLevel { + ORT_DISABLE_ALL = 0, + ORT_ENABLE_BASIC = 1, + ORT_ENABLE_EXTENDED = 2, + ORT_ENABLE_ALL = 99 +} GraphOptimizationLevel; +ORT_API_STATUS(OrtSetSessionGraphOptimizationLevel, _Inout_ OrtSessionOptions* options, + GraphOptimizationLevel graph_optimization_level); + +/** + * How many threads in the session thread pool. + * Set it to 0 to make onnxruntime run as single threaded. + * \param session_thread_pool_size <0, let the runtime choose a default. =0, Don't create extra threads. + * >0, create a thread pool with size of this value. + */ ORT_API_STATUS(OrtSetSessionThreadPoolSize, _Inout_ OrtSessionOptions* options, int session_thread_pool_size); /** @@ -279,9 +293,11 @@ ORT_API_STATUS(OrtSessionGetOutputName, _In_ const OrtSession* sess, size_t inde ORT_API_STATUS(OrtCreateRunOptions, _Outptr_ OrtRunOptions** out); ORT_API_STATUS(OrtRunOptionsSetRunLogVerbosityLevel, _Inout_ OrtRunOptions* options, int value); +ORT_API_STATUS(OrtRunOptionsSetRunLogSeverityLevel, _Inout_ OrtRunOptions* options, int value); ORT_API_STATUS(OrtRunOptionsSetRunTag, _In_ OrtRunOptions*, _In_ const char* run_tag); ORT_API_STATUS(OrtRunOptionsGetRunLogVerbosityLevel, _In_ const OrtRunOptions* options, _Out_ int* out); +ORT_API_STATUS(OrtRunOptionsGetRunLogSeverityLevel, _In_ const OrtRunOptions* options, _Out_ int* out); ORT_API_STATUS(OrtRunOptionsGetRunTag, _In_ const OrtRunOptions*, _Out_ const char** out); // Set a flag so that any running OrtRun* calls that are using this instance of OrtRunOptions @@ -336,35 +352,6 @@ ORT_API_STATUS(OrtGetStringTensorDataLength, _In_ const OrtValue* value, _Out_ s ORT_API_STATUS(OrtGetStringTensorContent, _In_ const OrtValue* value, _Out_ void* s, size_t s_len, _Out_ size_t* offsets, size_t offsets_len); -/** - * Create an OrtValue in CPU memory from a serialized TensorProto - * @param input serialized TensorProto object - * @param input_len length of 'input'. - * @param input_file_path A local file path of where the input was loaded from. Can be NULL if the tensor proto doesn't - * have any external data or it was loaded from current working dir. This path could be either a - * relative path or an absolute path. - * @param preallocated A preallocated buffer for the tensor. 
It should be allocated from CPU memory - * @param preallocated_size Length of the preallocated buffer in bytes, can be computed from - * the OrtGetTensorMemSizeInBytesFromTensorProto function. This function will return an error if the - * preallocated_size is not enough. - * @param out - * @return - */ -ORT_API_STATUS(OrtTensorProtoToOrtValue, _In_ const void* input, int input_len, - _In_opt_ const ORTCHAR_T* input_file_path, _Inout_ void* preallocated, size_t preallocated_size, - _Outptr_ OrtValue** out, _Outptr_ OrtCallback** deleter); - -/** - * f will be freed in this call - */ -ORT_API(void, OrtRunCallback, _Frees_ptr_opt_ OrtCallback* f); - -/** - * calculate the memory requirement for the OrtTensorProtoToOrtValue function - */ -ORT_API_STATUS(OrtGetTensorMemSizeInBytesFromTensorProto, _In_ const void* input, int input_len, size_t alignment, - _Out_ size_t* out); - /** * Don't free the 'out' value */ @@ -461,14 +448,16 @@ ORT_API_STATUS(OrtAllocatorAlloc, _Inout_ OrtAllocator* ptr, size_t size, _Outpt ORT_API_STATUS(OrtAllocatorFree, _Inout_ OrtAllocator* ptr, void* p); ORT_API_STATUS(OrtAllocatorGetInfo, _In_ const OrtAllocator* ptr, _Out_ const OrtAllocatorInfo** out); -ORT_API_STATUS(OrtCreateDefaultAllocator, _Outptr_ OrtAllocator** out); +// The returned pointer doesn't have to be freed. +// Always returns the same instance on every invocation. +ORT_API_STATUS(OrtGetAllocatorWithDefaultOptions, _Outptr_ OrtAllocator** out); ORT_API(const char*, OrtGetVersionString); /** * \param msg A null-terminated string. Its content will be copied into the newly created OrtStatus */ -ORT_API(OrtStatus*, OrtCreateStatus, OrtErrorCode code, _In_ const char* msg) -ORT_ALL_ARGS_NONNULL; +ORT_EXPORT _Check_return_ _Ret_notnull_ OrtStatus* ORT_API_CALL OrtCreateStatus(OrtErrorCode code, _In_ const char* msg) NO_EXCEPTION + ORT_ALL_ARGS_NONNULL; ORT_API(OrtErrorCode, OrtGetErrorCode, _In_ const OrtStatus* status) ORT_ALL_ARGS_NONNULL; diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_api.h b/include/onnxruntime/core/session/onnxruntime_cxx_api.h index e21e87596781e..82263bb071a7a 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_api.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_api.h @@ -43,7 +43,6 @@ struct Exception : std::exception { #define ORT_DEFINE_RELEASE(NAME) \ inline void OrtRelease(Ort##NAME* ptr) { OrtRelease##NAME(ptr); } -ORT_DEFINE_RELEASE(Allocator); ORT_DEFINE_RELEASE(AllocatorInfo); ORT_DEFINE_RELEASE(CustomOpDomain); ORT_DEFINE_RELEASE(Env); @@ -93,7 +92,7 @@ struct Unowned : T { ~Unowned() { this->p_ = nullptr; } }; -struct Allocator; +struct AllocatorWithDefaultOptions; struct AllocatorInfo; struct Env; struct TypeInfo; @@ -120,6 +119,9 @@ struct RunOptions : Base { RunOptions& SetRunLogVerbosityLevel(int); int GetRunLogVerbosityLevel() const; + RunOptions& SetRunLogSeverityLevel(int); + int GetRunLogSeverityLevel() const; + RunOptions& SetRunTag(const char* run_tag); const char* GetRunTag() const; @@ -135,11 +137,13 @@ struct SessionOptions : Base { SessionOptions Clone() const; SessionOptions& SetThreadPoolSize(int session_thread_pool_size); - SessionOptions& SetGraphOptimizationLevel(int graph_optimization_level); + SessionOptions& SetGraphOptimizationLevel(GraphOptimizationLevel graph_optimization_level); SessionOptions& EnableCpuMemArena(); SessionOptions& DisableCpuMemArena(); + SessionOptions& SetOptimizedModelFilePath(const ORTCHAR_T* optimized_model_file); + SessionOptions& EnableProfiling(const ORTCHAR_T* profile_file_prefix); 
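// Hedged end-to-end sketch (not part of the patch) for the session-option
// additions above: OrtSetOptimizedModelFilePath, the GraphOptimizationLevel
// enum, and the clarified OrtSetSessionThreadPoolSize contract. Assumes a
// non-Windows build where ORTCHAR_T is char; the output file name is a
// placeholder.
#include "core/session/onnxruntime_c_api.h"
#include <cstdio>

static void CheckOrtStatus(OrtStatus* status) {
  if (status != nullptr) {  // a null status means success
    std::fprintf(stderr, "ORT error code: %d\n", static_cast<int>(OrtGetErrorCode(status)));
    OrtReleaseStatus(status);
  }
}

int main() {
  OrtSessionOptions* session_options = nullptr;
  CheckOrtStatus(OrtCreateSessionOptions(&session_options));
  // Run basic + extended graph transformations and persist the optimized graph.
  CheckOrtStatus(OrtSetSessionGraphOptimizationLevel(session_options, ORT_ENABLE_EXTENDED));
  CheckOrtStatus(OrtSetOptimizedModelFilePath(session_options, "model.optimized.onnx"));
  // 0: do not create extra threads; a negative value lets the runtime choose.
  CheckOrtStatus(OrtSetSessionThreadPoolSize(session_options, 0));
  OrtReleaseSessionOptions(session_options);
  return 0;
}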
SessionOptions& DisableProfiling(); @@ -225,16 +229,19 @@ struct Value : Base { TensorTypeAndShapeInfo GetTensorTypeAndShapeInfo() const; }; -struct Allocator : Base { - static Allocator CreateDefault(); +struct AllocatorWithDefaultOptions { + AllocatorWithDefaultOptions(); - explicit Allocator(nullptr_t) {} - explicit Allocator(OrtAllocator* p) : Base{p} {} + operator OrtAllocator*() { return p_; } + operator const OrtAllocator*() const { return p_; } void* Alloc(size_t size); void Free(void* p); const OrtAllocatorInfo* GetInfo() const; + + private: + OrtAllocator* p_{}; }; struct AllocatorInfo : Base { diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h index 0fbbbde445b16..3670fdfe71bc6 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h @@ -39,23 +39,21 @@ struct TypeToTensorType { static constexpr ONNXTensorElementDataType t template <> struct TypeToTensorType { static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64; }; -inline Allocator Allocator::CreateDefault() { - OrtAllocator* p; - ORT_THROW_ON_ERROR(OrtCreateDefaultAllocator(&p)); - return Allocator(p); +inline AllocatorWithDefaultOptions::AllocatorWithDefaultOptions() { + ORT_THROW_ON_ERROR(OrtGetAllocatorWithDefaultOptions(&p_)); } -inline void* Allocator::Alloc(size_t size) { +inline void* AllocatorWithDefaultOptions::Alloc(size_t size) { void* out; ORT_THROW_ON_ERROR(OrtAllocatorAlloc(p_, size, &out)); return out; } -inline void Allocator::Free(void* p) { +inline void AllocatorWithDefaultOptions::Free(void* p) { ORT_THROW_ON_ERROR(OrtAllocatorFree(p_, p)); } -inline const OrtAllocatorInfo* Allocator::GetInfo() const { +inline const OrtAllocatorInfo* AllocatorWithDefaultOptions::GetInfo() const { const OrtAllocatorInfo* out; ORT_THROW_ON_ERROR(OrtAllocatorGetInfo(p_, &out)); return out; @@ -96,6 +94,11 @@ inline RunOptions& RunOptions::SetRunLogVerbosityLevel(int level) { return *this; } +inline RunOptions& RunOptions::SetRunLogSeverityLevel(int level) { + ORT_THROW_ON_ERROR(OrtRunOptionsSetRunLogSeverityLevel(p_, level)); + return *this; +} + inline int RunOptions::GetRunLogVerbosityLevel() const { int out; ORT_THROW_ON_ERROR(OrtRunOptionsGetRunLogVerbosityLevel(p_, &out)); @@ -138,11 +141,16 @@ inline SessionOptions& SessionOptions::SetThreadPoolSize(int session_thread_pool return *this; } -inline SessionOptions& SessionOptions::SetGraphOptimizationLevel(int graph_optimization_level) { +inline SessionOptions& SessionOptions::SetGraphOptimizationLevel(GraphOptimizationLevel graph_optimization_level) { ORT_THROW_ON_ERROR(OrtSetSessionGraphOptimizationLevel(p_, graph_optimization_level)); return *this; } +inline SessionOptions& SessionOptions::SetOptimizedModelFilePath(const ORTCHAR_T* optimized_model_filepath) { + ORT_THROW_ON_ERROR(OrtSetOptimizedModelFilePath(p_, optimized_model_filepath)); + return *this; +} + inline SessionOptions& SessionOptions::EnableProfiling(const ORTCHAR_T* profile_file_prefix) { ORT_THROW_ON_ERROR(OrtEnableProfiling(p_, profile_file_prefix)); return *this; diff --git a/onnxruntime/__init__.py b/onnxruntime/__init__.py index 29e8f5fb33ebf..4d222e6945916 100644 --- a/onnxruntime/__init__.py +++ b/onnxruntime/__init__.py @@ -18,4 +18,4 @@ from onnxruntime.capi import onnxruntime_validation onnxruntime_validation.check_distro_info() from onnxruntime.capi.session import InferenceSession -from 
onnxruntime.capi._pybind_state import RunOptions, SessionOptions, set_default_logger_severity, get_device, NodeArg, ModelMetadata +from onnxruntime.capi._pybind_state import get_device, RunOptions, SessionOptions, NodeArg, ModelMetadata, GraphOptimizationLevel diff --git a/onnxruntime/automl_ops/automl_featurizers.h b/onnxruntime/automl_ops/automl_featurizers.h new file mode 100644 index 0000000000000..37e6e982d9a62 --- /dev/null +++ b/onnxruntime/automl_ops/automl_featurizers.h @@ -0,0 +1,8 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// Cumulative header with automl featurizers includes exposed to +// ORT +#pragma once + +#include "core/automl/featurizers/src/FeaturizerPrep/Featurizers/DateTimeFeaturizer.h" diff --git a/onnxruntime/automl_ops/automl_types.cc b/onnxruntime/automl_ops/automl_types.cc new file mode 100644 index 0000000000000..8f0cb77701606 --- /dev/null +++ b/onnxruntime/automl_ops/automl_types.cc @@ -0,0 +1,39 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/common/common.h" +#include "core/framework/data_types.h" +#include "core/framework/op_kernel.h" + +#include "automl_ops/automl_types.h" +#include "automl_ops/automl_featurizers.h" + +namespace dtf = Microsoft::Featurizer::DateTimeFeaturizer; + +namespace onnxruntime { + +// This temporary to register custom types so ORT is aware of it +// although it still can not serialize such a type. +// These character arrays must be extern so the resulting instantiated template +// is globally unique + +extern const char kMsAutoMLDomain[] = "com.microsoft.automl"; + +extern const char kTimepointName[] = "DateTimeFeaturizer_TimePoint"; +// This has to be under onnxruntime to properly specialize a function template +ORT_REGISTER_OPAQUE_TYPE(dtf::TimePoint, kMsAutoMLDomain, kTimepointName); + +namespace automl { + +#define REGISTER_CUSTOM_PROTO(TYPE, reg_fn) \ + { \ + MLDataType mltype = DataTypeImpl::GetType(); \ + reg_fn(mltype); \ + } + +void RegisterAutoMLTypes(const std::function& reg_fn) { + REGISTER_CUSTOM_PROTO(dtf::TimePoint, reg_fn); +} +#undef REGISTER_CUSTOM_PROTO +} // namespace automl +} // namespace onnxruntime diff --git a/onnxruntime/automl_ops/automl_types.h b/onnxruntime/automl_ops/automl_types.h new file mode 100644 index 0000000000000..798c6778966bb --- /dev/null +++ b/onnxruntime/automl_ops/automl_types.h @@ -0,0 +1,13 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/framework/data_types.h" +#include + +namespace onnxruntime { +namespace automl { +void RegisterAutoMLTypes(const std::function& reg_fn); +} // namespace automl +} // namespace onnxruntime diff --git a/onnxruntime/automl_ops/cpu/datetime_transformer.cc b/onnxruntime/automl_ops/cpu/datetime_transformer.cc new file mode 100644 index 0000000000000..05a655f8d7453 --- /dev/null +++ b/onnxruntime/automl_ops/cpu/datetime_transformer.cc @@ -0,0 +1,42 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
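// Hedged sketch (not in the patch): consuming the RegisterAutoMLTypes hook
// declared above. The collector below is illustrative; a real caller would
// forward each MLDataType to the runtime's type registry. Assumes the callback
// signature std::function<void(MLDataType)> implied by the registration code.
#include "automl_ops/automl_types.h"
#include <vector>

namespace onnxruntime {
inline std::vector<MLDataType> CollectAutoMLTypes() {
  std::vector<MLDataType> types;
  automl::RegisterAutoMLTypes([&types](MLDataType t) { types.push_back(t); });
  return types;  // currently just the DateTimeFeaturizer TimePoint opaque type
}
}  // namespace onnxruntime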
+ +#include "core/common/common.h" +#include "core/framework/data_types.h" +#include "core/framework/op_kernel.h" + +#include "core/automl/featurizers/src/FeaturizerPrep/Featurizers/DateTimeFeaturizer.h" + +namespace dtf = Microsoft::Featurizer::DateTimeFeaturizer; + +namespace onnxruntime { +namespace automl { + +class DateTimeTransformer final : public OpKernel { + public: + explicit DateTimeTransformer(const OpKernelInfo& info) : OpKernel(info) {} + Status Compute(OpKernelContext* context) const override; +}; + +Status DateTimeTransformer::Compute(OpKernelContext* ctx) const { + Status s; + auto input_tensor = ctx->Input(0); + dtf::TimePoint* output = ctx->Output(0); + + int64_t tp = *input_tensor->Data(); + std::chrono::system_clock::time_point sys_time{std::chrono::seconds(tp)}; + *output = dtf::SystemToDPTimePoint(sys_time); + return s; +} + +ONNX_OPERATOR_KERNEL_EX( + DateTimeTransformer, + kMSAutoMLDomain, + 1, + kCpuExecutionProvider, + KernelDefBuilder() + .TypeConstraint("T1", DataTypeImpl::GetTensorType()) + .TypeConstraint("T2", DataTypeImpl::GetType()), + DateTimeTransformer); +} // namespace automl +} // namespace onnxruntime diff --git a/onnxruntime/automl_ops/cpu_automl_kernels.cc b/onnxruntime/automl_ops/cpu_automl_kernels.cc new file mode 100644 index 0000000000000..23d5e2ad72e6a --- /dev/null +++ b/onnxruntime/automl_ops/cpu_automl_kernels.cc @@ -0,0 +1,25 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "automl_ops/cpu_automl_kernels.h" +#include "core/graph/constants.h" +#include "core/framework/data_types.h" + +namespace onnxruntime { +namespace automl { + +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSAutoMLDomain, 1, DateTimeTransformer); + +void RegisterCpuAutoMLKernels(KernelRegistry& kernel_registry) { + static const BuildKernelCreateInfoFn function_table[] = { + // add more kernels here + BuildKernelCreateInfo + }; + + for (auto& function_table_entry : function_table) { + kernel_registry.Register(function_table_entry()); + } +} + +} // namespace automl +} // namespace onnxruntime diff --git a/onnxruntime/automl_ops/cpu_automl_kernels.h b/onnxruntime/automl_ops/cpu_automl_kernels.h new file mode 100644 index 0000000000000..f14a8983d5a39 --- /dev/null +++ b/onnxruntime/automl_ops/cpu_automl_kernels.h @@ -0,0 +1,13 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
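// Hedged sketch (not in the patch): wiring the new AutoML kernel registration
// into a kernel registry. The factory function is illustrative and assumes
// KernelRegistry is default-constructible, as elsewhere in the runtime.
#include "automl_ops/cpu_automl_kernels.h"
#include <memory>

namespace onnxruntime {
inline std::shared_ptr<KernelRegistry> CreateCpuAutoMLKernelRegistry() {
  auto registry = std::make_shared<KernelRegistry>();
  // Registers DateTimeTransformer (domain com.microsoft.automl, opset 1) and
  // any future entries added to the function table above.
  automl::RegisterCpuAutoMLKernels(*registry);
  return registry;
}
}  // namespace onnxruntime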
+ +#pragma once + +#include "core/framework/op_kernel.h" +#include "core/framework/kernel_registry.h" + +namespace onnxruntime { +namespace automl { +void RegisterCpuAutoMLKernels(KernelRegistry& kernel_registry); +} // namespace automl +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cpu/attnlstm/attention_wrapper.cc b/onnxruntime/contrib_ops/cpu/attnlstm/attention_wrapper.cc index 4555713a59fe1..8757ccb35f771 100644 --- a/onnxruntime/contrib_ops/cpu/attnlstm/attention_wrapper.cc +++ b/onnxruntime/contrib_ops/cpu/attnlstm/attention_wrapper.cc @@ -16,7 +16,7 @@ template AttentionWrapper::AttentionWrapper(AllocatorPtr alloc, const logging::Logger& logger, int batch_size, int attn_context_depth, int attn_layer_depth, int inner_cell_hidden_size, bool has_attn_layer, - const IAttentionMechanism& attention_mechanism) + const IAttentionMechanism& attention_mechanism, concurrency::ThreadPool* threadpool) : allocator_(alloc), logger_(logger), batch_size_(batch_size), @@ -24,7 +24,8 @@ AttentionWrapper::AttentionWrapper(AllocatorPtr alloc, const logging::Logger& attn_layer_depth_(attn_layer_depth), inner_cell_hidden_size_(inner_cell_hidden_size), has_attn_layer_(has_attn_layer), - attention_mechanism_(attention_mechanism) { + attention_mechanism_(attention_mechanism), + ttp_(threadpool) { auto mem_max_steps = attention_mechanism_.GetMaxMemorySteps(); prev_alignments_ = Allocate(allocator_, batch_size_ * mem_max_steps, prev_alignments_ptr_, true); alignments_ = Allocate(allocator_, batch_size_ * mem_max_steps, alignments_ptr_, true); @@ -37,11 +38,11 @@ template void AttentionWrapper::ProcessOutput(const gsl::span& rnn_cell_output) { if (has_attn_layer_) { // rnn_cell_output * cell_weights, (part of the attention layer above the attention mechanism). - math::GemmEx(CblasNoTrans, CblasNoTrans, - batch_size_, attn_layer_depth_, inner_cell_hidden_size_, T{1.0}, - rnn_cell_output.data(), inner_cell_hidden_size_, - attn_layer_cell_weights_.data(), attn_layer_depth_, T{0.0}, - attn_states_.data(), attn_layer_depth_, &CPUMathUtil::Instance()); + math::GemmEx(CblasNoTrans, CblasNoTrans, + batch_size_, attn_layer_depth_, inner_cell_hidden_size_, T{1.0}, + rnn_cell_output.data(), inner_cell_hidden_size_, + attn_layer_cell_weights_.data(), attn_layer_depth_, T{0.0}, + attn_states_.data(), attn_layer_depth_, ttp_); } // Get the context which is calculated within attention mechanism. @@ -54,11 +55,11 @@ void AttentionWrapper::ProcessOutput(const gsl::span& rnn_cell_outpu //concat([p_cell_output, context]) * stack([attn_layer_cell_weights_, attn_layer_attn_weights_]) = // p_cell_output * attn_layer_cell_weights_ + context * attn_layer_attn_weights_ // The first part is calulated above. Here just add the later. 
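// Hedged sketch (not in the patch): the thread-pool plumbing pattern this
// change applies to the attention-LSTM code. A kernel's Compute() obtains the
// session thread pool through the internal kernel context and passes it to
// math::GemmEx in place of &CPUMathUtil::Instance(); the helper name here is
// illustrative, while the cast and accessor match the kernels updated below.
#include "core/framework/op_kernel_context_internal.h"
#include "core/platform/threadpool.h"

namespace onnxruntime {
inline concurrency::ThreadPool* OperatorThreadPoolOf(OpKernelContext& context) {
  auto* ctx_internal = static_cast<OpKernelContextInternal*>(&context);
  return ctx_internal->GetOperatorThreadPool();  // may be null if no pool was created
}
}  // namespace onnxruntime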
- math::GemmEx(CblasNoTrans, CblasNoTrans, - batch_size_, attn_layer_depth_, attn_context_depth_, T{1.0}, - attn_context_.data(), attn_context_depth_, - attn_layer_attn_weights_.data(), attn_layer_depth_, T{1.0}, - attn_states_.data(), attn_layer_depth_, &CPUMathUtil::Instance()); + math::GemmEx(CblasNoTrans, CblasNoTrans, + batch_size_, attn_layer_depth_, attn_context_depth_, T{1.0}, + attn_context_.data(), attn_context_depth_, + attn_layer_attn_weights_.data(), attn_layer_depth_, T{1.0}, + attn_states_.data(), attn_layer_depth_, ttp_); } } diff --git a/onnxruntime/contrib_ops/cpu/attnlstm/attention_wrapper.h b/onnxruntime/contrib_ops/cpu/attnlstm/attention_wrapper.h index 2469a7b99a3fb..b6cc06c040e3a 100644 --- a/onnxruntime/contrib_ops/cpu/attnlstm/attention_wrapper.h +++ b/onnxruntime/contrib_ops/cpu/attnlstm/attention_wrapper.h @@ -8,6 +8,7 @@ #include "core/common/common.h" #include "core/common/logging/logging.h" #include "core/framework/allocator.h" +#include "core/platform/threadpool.h" namespace onnxruntime { namespace contrib { @@ -22,7 +23,7 @@ class AttentionWrapper { int attn_layer_depth, int inner_cell_hidden_size, bool has_attn_layer, - const IAttentionMechanism& attention_mechanism); + const IAttentionMechanism& attention_mechanism, concurrency::ThreadPool* threadpool); virtual ~AttentionWrapper() = default; @@ -69,6 +70,7 @@ class AttentionWrapper { bool has_attn_layer_; const IAttentionMechanism& attention_mechanism_; + concurrency::ThreadPool* ttp_; }; } // namespace contrib diff --git a/onnxruntime/contrib_ops/cpu/attnlstm/bahdanau_attention.cc b/onnxruntime/contrib_ops/cpu/attnlstm/bahdanau_attention.cc index 932ac263f8e22..74ad84b5af839 100644 --- a/onnxruntime/contrib_ops/cpu/attnlstm/bahdanau_attention.cc +++ b/onnxruntime/contrib_ops/cpu/attnlstm/bahdanau_attention.cc @@ -15,8 +15,8 @@ namespace contrib { template BahdanauAttention::BahdanauAttention(AllocatorPtr allocator, const logging::Logger& logger, int batch_size, int max_memory_step, int memory_depth, - int query_depth, int attn_depth, bool normalize) - : allocator_(allocator), logger_(logger), batch_size_(batch_size), max_memory_steps_(max_memory_step), memory_depth_(memory_depth), query_depth_(query_depth), attn_depth_(attn_depth), normalize_(normalize) { + int query_depth, int attn_depth, bool normalize, concurrency::ThreadPool* threadpool) + : allocator_(allocator), logger_(logger), batch_size_(batch_size), max_memory_steps_(max_memory_step), memory_depth_(memory_depth), query_depth_(query_depth), attn_depth_(attn_depth), normalize_(normalize), ttp_(threadpool) { values_ = Allocate(allocator_, batch_size_ * max_memory_steps_ * memory_depth_, values_ptr_, true); keys_ = Allocate(allocator_, batch_size_ * max_memory_steps_ * attn_depth_, keys_ptr_, true); processed_query_ = Allocate(allocator_, batch_size_ * attn_depth_, processed_query_ptr_, true); @@ -72,11 +72,11 @@ void BahdanauAttention::PrepareMemory( "Real memory steps ", mem_steps, " is not in (0, ", max_memory_steps_, "]"); } - math::GemmEx(CblasNoTrans, CblasNoTrans, - batch_size_ * max_memory_steps_, attn_depth_, memory_depth_, T{1.0}, - memory.data(), memory_depth_, - memory_layer_weights_.data(), attn_depth_, T{0.0}, - keys_.data(), attn_depth_, &CPUMathUtil::Instance()); + math::GemmEx(CblasNoTrans, CblasNoTrans, + batch_size_ * max_memory_steps_, attn_depth_, memory_depth_, T{1.0}, + memory.data(), memory_depth_, + memory_layer_weights_.data(), attn_depth_, T{0.0}, + keys_.data(), attn_depth_, ttp_); } template @@ -115,11 +115,11 @@ void 
BahdanauAttention::Compute( const gsl::span& output, const gsl::span& aligns) const { //process query in dense query layer without bias - math::GemmEx(CblasNoTrans, CblasNoTrans, - batch_size_, attn_depth_, query_depth_, T{1.0}, - queries.data(), query_depth_, - query_layer_weights_.data(), attn_depth_, T{0.0}, - processed_query_.data(), attn_depth_, &CPUMathUtil::Instance()); + math::GemmEx(CblasNoTrans, CblasNoTrans, + batch_size_, attn_depth_, query_depth_, T{1.0}, + queries.data(), query_depth_, + query_layer_weights_.data(), attn_depth_, T{0.0}, + processed_query_.data(), attn_depth_, ttp_); std::fill(aligns.begin(), aligns.end(), T{}); @@ -146,11 +146,11 @@ void BahdanauAttention::Compute( // Calculate the context auto outspan = output.subspan(b * memory_depth_); auto values = values_.subspan(b * max_memory_steps_ * memory_depth_); - math::GemmEx(CblasNoTrans, CblasNoTrans, - 1, memory_depth_, max_memory_steps_, T{1.0}, - alignments, max_memory_steps_, - values.data(), memory_depth_, T{0.0}, - outspan.data(), memory_depth_, &CPUMathUtil::Instance()); + math::GemmEx(CblasNoTrans, CblasNoTrans, + 1, memory_depth_, max_memory_steps_, T{1.0}, + alignments, max_memory_steps_, + values.data(), memory_depth_, T{0.0}, + outspan.data(), memory_depth_, ttp_); } } diff --git a/onnxruntime/contrib_ops/cpu/attnlstm/bahdanau_attention.h b/onnxruntime/contrib_ops/cpu/attnlstm/bahdanau_attention.h index 755af6ba6d5c3..c2bfee15c5bcc 100644 --- a/onnxruntime/contrib_ops/cpu/attnlstm/bahdanau_attention.h +++ b/onnxruntime/contrib_ops/cpu/attnlstm/bahdanau_attention.h @@ -23,7 +23,7 @@ class BahdanauAttention : public IAttentionMechanism { int memory_depth, int query_depth, int attn_depth, - bool normalize); + bool normalize, concurrency::ThreadPool* threadpool); void SetWeights( const gsl::span& attn_weights, @@ -77,6 +77,7 @@ class BahdanauAttention : public IAttentionMechanism { gsl::span mem_seq_lengths_; bool normalize_; + concurrency::ThreadPool* ttp_; }; } // namespace contrib diff --git a/onnxruntime/contrib_ops/cpu/attnlstm/deep_cpu_attn_lstm.cc b/onnxruntime/contrib_ops/cpu/attnlstm/deep_cpu_attn_lstm.cc index 7f7102475c620..50e98f834260b 100644 --- a/onnxruntime/contrib_ops/cpu/attnlstm/deep_cpu_attn_lstm.cc +++ b/onnxruntime/contrib_ops/cpu/attnlstm/deep_cpu_attn_lstm.cc @@ -8,7 +8,9 @@ #include "core/common/common.h" #include "core/common/logging/logging.h" +#include "core/platform/threadpool.h" #include "core/framework/allocator.h" +#include "core/framework/op_kernel_context_internal.h" namespace onnxruntime { namespace contrib { @@ -70,6 +72,9 @@ static gsl::span SecondHalfSpan(const gsl::span& dspan) { template Status DeepCpuAttnLstmOp::ComputeImpl(OpKernelContext& context) const { + auto ctx_internal = static_cast(&context); + concurrency::ThreadPool* thread_pool = ctx_internal->GetOperatorThreadPool(); + auto& logger = context.Logger(); // original lstm processing @@ -236,7 +241,7 @@ Status DeepCpuAttnLstmOp::ComputeImpl(OpKernelContext& context) const { memory_depth, query_depth, am_attn_size, - false); + false, thread_pool); fam.SetWeights( FirstHalfSpan(am_v_weights.DataAsSpan()), @@ -252,7 +257,7 @@ Status DeepCpuAttnLstmOp::ComputeImpl(OpKernelContext& context) const { attn_layer_depth, hidden_size_, has_attention_layer, - fam); + fam, thread_pool); faw.SetWeights(FirstHalfSpan(attn_layer_weights_span)); UniDirectionalAttnLstm fw( @@ -263,7 +268,7 @@ Status DeepCpuAttnLstmOp::ComputeImpl(OpKernelContext& context) const { activation_funcs_.Entries()[0], 
activation_funcs_.Entries()[1], activation_funcs_.Entries()[2], - clip_, ttp_); + clip_, thread_pool); BahdanauAttention bam( alloc, @@ -273,7 +278,7 @@ Status DeepCpuAttnLstmOp::ComputeImpl(OpKernelContext& context) const { memory_depth, query_depth, am_attn_size, - false); + false, thread_pool); bam.SetWeights( SecondHalfSpan(am_v_weights.DataAsSpan()), SecondHalfSpan(am_query_layer_weights.DataAsSpan()), @@ -288,7 +293,7 @@ Status DeepCpuAttnLstmOp::ComputeImpl(OpKernelContext& context) const { attn_layer_depth, hidden_size_, has_attention_layer, - bam); + bam, thread_pool); baw.SetWeights(SecondHalfSpan(attn_layer_weights_span)); UniDirectionalAttnLstm bw( @@ -299,7 +304,7 @@ Status DeepCpuAttnLstmOp::ComputeImpl(OpKernelContext& context) const { activation_funcs_.Entries()[3], activation_funcs_.Entries()[4], activation_funcs_.Entries()[5], - clip_, ttp_); + clip_, thread_pool); fw.Compute(input, sequence_lens_span, num_directions_, input_weights_1, recurrent_weights_1, output_1, hidden_output_1, last_cell_1); bw.Compute(input, sequence_lens_span, num_directions_, input_weights_2, hidden_weights_2, output_2, hidden_output_2, last_cell_2); @@ -313,7 +318,7 @@ Status DeepCpuAttnLstmOp::ComputeImpl(OpKernelContext& context) const { memory_depth, query_depth, am_attn_size, - false); + false, thread_pool); fam.SetWeights( am_v_weights.DataAsSpan(), @@ -329,7 +334,7 @@ Status DeepCpuAttnLstmOp::ComputeImpl(OpKernelContext& context) const { attn_layer_depth, hidden_size_, has_attention_layer, - fam); + fam, thread_pool); faw.SetWeights(attn_layer_weights_span); @@ -341,7 +346,7 @@ Status DeepCpuAttnLstmOp::ComputeImpl(OpKernelContext& context) const { activation_funcs_.Entries()[0], activation_funcs_.Entries()[1], activation_funcs_.Entries()[2], - clip_, ttp_); + clip_, thread_pool); fw.Compute(input, sequence_lens_span, num_directions_, input_weights_1, recurrent_weights_1, output_1, hidden_output_1, last_cell_1); } diff --git a/onnxruntime/contrib_ops/cpu/attnlstm/uni_dir_attn_lstm.cc b/onnxruntime/contrib_ops/cpu/attnlstm/uni_dir_attn_lstm.cc index caa05f9d5ceff..4183b6e2d6de4 100644 --- a/onnxruntime/contrib_ops/cpu/attnlstm/uni_dir_attn_lstm.cc +++ b/onnxruntime/contrib_ops/cpu/attnlstm/uni_dir_attn_lstm.cc @@ -45,7 +45,7 @@ UniDirectionalAttnLstm::UniDirectionalAttnLstm(AllocatorPtr allocator, const ActivationFuncs::Entry& activation_func_g, const ActivationFuncs::Entry& activation_func_h, const float clip, - onnxruntime::concurrency::ThreadPool& ttp) + onnxruntime::concurrency::ThreadPool* ttp) : allocator_(allocator), logger_(logger), seq_length_(seq_length), @@ -254,7 +254,7 @@ void UniDirectionalAttnLstm::Compute(const gsl::span& inputs_arg, input_weights.cbegin(), input_weights.cend(), // W[iofc]^T input_size_ + attention_size_, T{0.0}, output_iofc_.begin(), output_iofc_.end(), - hidden_size_x4); + hidden_size_x4, ttp_); DumpMatrix("Xt*(W[iofc]^T)", output_iofc_.data(), total_rows, hidden_size_x4); @@ -296,7 +296,7 @@ void UniDirectionalAttnLstm::Compute(const gsl::span& inputs_arg, input_weights.cbegin() + input_size_, input_weights.cend(), // WA[iofc] input_size_ + attention_size_, T{1.0}, step_out_IOFC, output_iofc_.end(), // input contains Xt*(W[iofc]^T) - hidden_size_x4); + hidden_size_x4, ttp_); // calculate Xt*(W[iofc]^T) + Ht-1*R[iofc] ComputeGemm(batch_size_, hidden_size_x4, hidden_size_, T{1.0}, @@ -305,7 +305,7 @@ void UniDirectionalAttnLstm::Compute(const gsl::span& inputs_arg, recurrent_weights.cbegin(), recurrent_weights.cend(), // R[iofc] hidden_size_, T{1.0}, 
step_out_IOFC, output_iofc_.end(), // input contains Xt*(W[iofc]^T) - hidden_size_x4); + hidden_size_x4, ttp_); span_T_iter batched_output, batched_output_end; if (output_sequence) { diff --git a/onnxruntime/contrib_ops/cpu/attnlstm/uni_dir_attn_lstm.h b/onnxruntime/contrib_ops/cpu/attnlstm/uni_dir_attn_lstm.h index 5a8e4e3224a25..2d3a6f20fe1e9 100644 --- a/onnxruntime/contrib_ops/cpu/attnlstm/uni_dir_attn_lstm.h +++ b/onnxruntime/contrib_ops/cpu/attnlstm/uni_dir_attn_lstm.h @@ -51,7 +51,7 @@ class UniDirectionalAttnLstm { const ActivationFuncs::Entry& activation_func_g, const ActivationFuncs::Entry& activation_func_h, const float clip, - onnxruntime::concurrency::ThreadPool& ttp); + onnxruntime::concurrency::ThreadPool* ttp); void Compute(const gsl::span& inputs, const gsl::span& sequence_lengths, @@ -151,7 +151,7 @@ class UniDirectionalAttnLstm { AttentionWrapper& attention_wrapper_; - onnxruntime::concurrency::ThreadPool& ttp_; + onnxruntime::concurrency::ThreadPool* ttp_; }; } // namespace detail diff --git a/onnxruntime/contrib_ops/cpu/matmul_integer16.cc b/onnxruntime/contrib_ops/cpu/matmul_integer16.cc new file mode 100644 index 0000000000000..7378cd56510d5 --- /dev/null +++ b/onnxruntime/contrib_ops/cpu/matmul_integer16.cc @@ -0,0 +1,45 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "contrib_ops/cpu/matmul_integer16.h" +#include "core/providers/cpu/math/matmul_helper.h" + +namespace onnxruntime { +namespace contrib { + +ONNX_OPERATOR_KERNEL_EX( + MatMulInteger16, + kMSDomain, + 1, + kCpuExecutionProvider, + KernelDefBuilder() + .TypeConstraint("T1", DataTypeImpl::GetTensorType()) + .TypeConstraint("T2", DataTypeImpl::GetTensorType()) + .TypeConstraint("T3", DataTypeImpl::GetTensorType()), + MatMulInteger16); + +template <> +Status MatMulInteger16::Compute(OpKernelContext* ctx) const { + auto A = ctx->Input(0); + auto B = ctx->Input(1); + ORT_ENFORCE(A != nullptr && B != nullptr); + + MatMulComputeHelper helper; + ORT_RETURN_IF_ERROR(helper.Compute(A->Shape(), B->Shape())); + Tensor* Y = ctx->Output(0, helper.OutputShape()); + + for (int i = 0; i < static_cast(helper.OutputOffsets().size()); i++) { + EigenCastGEMM( + A->template Data() + helper.LeftOffsets()[i], + B->template Data() + helper.RightOffsets()[i], + Y->template MutableData() + helper.OutputOffsets()[i], + static_cast(helper.M()), + static_cast(helper.N()), + static_cast(helper.K())); + } + + return Status::OK(); +} + +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cpu/matmul_integer16.h b/onnxruntime/contrib_ops/cpu/matmul_integer16.h new file mode 100644 index 0000000000000..633e8eee52b6a --- /dev/null +++ b/onnxruntime/contrib_ops/cpu/matmul_integer16.h @@ -0,0 +1,22 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
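// Hedged sketch (not in the patch): a minimal unit test for the MatMulInteger16
// contrib kernel above, written against the repository's existing OpTester
// helper; expected values are a plain 2x2 matrix product.
#include "gtest/gtest.h"
#include "test/providers/provider_test_utils.h"

namespace onnxruntime {
namespace test {

TEST(MatMulInteger16OpTest, Simple2x2) {
  OpTester test("MatMulInteger16", 1, kMSDomain);
  // int16 inputs, int32 accumulation/output, per the type constraints above.
  test.AddInput<int16_t>("A", {2, 2}, {1, 2, 3, 4});
  test.AddInput<int16_t>("B", {2, 2}, {5, 6, 7, 8});
  test.AddOutput<int32_t>("Y", {2, 2}, {19, 22, 43, 50});
  test.Run();
}

}  // namespace test
}  // namespace onnxruntime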
+ +#pragma once + +#include "core/common/common.h" +#include "core/framework/op_kernel.h" +#include "core/util/math_cpuonly.h" + +namespace onnxruntime { +namespace contrib { + +template +class MatMulInteger16 final : public OpKernel { + public: + MatMulInteger16(const OpKernelInfo& info) : OpKernel(info) { + } + + Status Compute(OpKernelContext* context) const override; +}; +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cpu/nchwc_ops.cc b/onnxruntime/contrib_ops/cpu/nchwc_ops.cc index b5625551ad104..3b14b21a79533 100644 --- a/onnxruntime/contrib_ops/cpu/nchwc_ops.cc +++ b/onnxruntime/contrib_ops/cpu/nchwc_ops.cc @@ -170,9 +170,6 @@ Status NchwcPoolBase::NchwcPool(OpKernelContext* context, MLAS_POOLING_KIND kind ORT_ENFORCE(X_shape.NumDimensions() == 4); ORT_ENFORCE((X_shape[1] % MlasNchwcGetBlockSize()) == 0); - if (!global_pooling_) { - ORT_RETURN_IF_NOT(kernel_shape_.size() == 2, "kernel_shape num_dims is not compatible with X num_dims."); - } std::vector pads = pads_; std::vector output_dims = PoolBase::SetOutputSize(X_shape, X_shape[1], &pads, dilations_, ceil_mode_); diff --git a/onnxruntime/contrib_ops/cpu/nchwc_ops.h b/onnxruntime/contrib_ops/cpu/nchwc_ops.h index 65045cd0eeb85..b9f8993114094 100644 --- a/onnxruntime/contrib_ops/cpu/nchwc_ops.h +++ b/onnxruntime/contrib_ops/cpu/nchwc_ops.h @@ -50,6 +50,8 @@ class NchwcConv : public OpKernel, public ConvBase { class NchwcPoolBase : public PoolBase { public: NchwcPoolBase(const OpKernelInfo& info) : PoolBase(info) { + if (!global_pooling_) + ORT_ENFORCE(kernel_shape_.size() == 2, "kernel_shape num_dims is not compatible with X num_dims."); } Status NchwcPool(OpKernelContext* context, MLAS_POOLING_KIND kind) const; diff --git a/onnxruntime/contrib_ops/cpu/word_conv_embedding.cc b/onnxruntime/contrib_ops/cpu/word_conv_embedding.cc index 7d7f577d5e3a1..3213ff4fc1db3 100644 --- a/onnxruntime/contrib_ops/cpu/word_conv_embedding.cc +++ b/onnxruntime/contrib_ops/cpu/word_conv_embedding.cc @@ -6,6 +6,7 @@ #include "core/util/math.h" #include "core/util/math_cpuonly.h" #include "core/mlas/inc/mlas.h" +#include "core/framework/op_kernel_context_internal.h" namespace onnxruntime { namespace contrib { @@ -45,7 +46,7 @@ void WordConvEmbedding::ComputeConvMaxPoolWithActivation( int64_t char_embedding_size, int64_t filter_width, int64_t num_filters, - float* output) const { + float* output, concurrency::ThreadPool* tp) const { int64_t input_word_size = word_len * char_embedding_size; int64_t unfolded_width = word_len - filter_width + 1; int64_t unfolded_kernal_size = filter_width * char_embedding_size; @@ -83,12 +84,12 @@ void WordConvEmbedding::ComputeConvMaxPoolWithActivation( tmp_word_inx++; } - math::GemmEx( + math::GemmEx( CblasNoTrans, CblasTrans, static_cast(words_unfolded_width), static_cast(num_filters), static_cast(unfolded_kernal_size), 1.0f, unfolded_buffer_p.get(), static_cast(unfolded_kernal_size), weights, static_cast(unfolded_kernal_size), 0.0f, - conv_buf_p, static_cast(num_filters), &CPUMathUtil::Instance()); + conv_buf_p, static_cast(num_filters), tp); for (int64_t unfolded_inx = 0; unfolded_inx < words_unfolded_width; unfolded_inx++) for (int64_t filter_inx = 0; filter_inx < num_filters; filter_inx++) { @@ -160,6 +161,9 @@ Status WordConvEmbedding::ValidateInputShape(const TensorShape& w_conv_shape, co } Status WordConvEmbedding::Compute(OpKernelContext* ctx) const { + auto ctx_internal = static_cast(ctx); + concurrency::ThreadPool* tp = ctx_internal->GetOperatorThreadPool(); + // 
original lstm processing const Tensor& sequence = *(ctx->Input(0)); // sequence: [sequence_length, word_length] const Tensor& w_conv = *(ctx->Input(1)); // conv weight: [M, C/group, kH, kW] @@ -216,7 +220,7 @@ Status WordConvEmbedding::Compute(OpKernelContext* ctx) const { char_embedding_size, filter_width, filter_size, - Y->MutableData()); + Y->MutableData(), tp); return Status::OK(); } diff --git a/onnxruntime/contrib_ops/cpu/word_conv_embedding.h b/onnxruntime/contrib_ops/cpu/word_conv_embedding.h index e74afab169fd8..5ee4127e3bfb9 100644 --- a/onnxruntime/contrib_ops/cpu/word_conv_embedding.h +++ b/onnxruntime/contrib_ops/cpu/word_conv_embedding.h @@ -8,6 +8,9 @@ #include "core/framework/tensor.h" namespace onnxruntime { +namespace concurrency { +class ThreadPool; +} namespace contrib { class WordConvEmbedding final : public OpKernel { @@ -38,7 +41,7 @@ class WordConvEmbedding final : public OpKernel { int64_t char_embedding_size, int64_t filter_width, int64_t num_filters, - float* output) const; + float* output, onnxruntime::concurrency::ThreadPool* tp) const; void CalculateLengthOfEachWordInSequence( const int* seq_ptr, int* words_len_ptr, diff --git a/onnxruntime/contrib_ops/cpu_contrib_kernels.cc b/onnxruntime/contrib_ops/cpu_contrib_kernels.cc index 8446a35bd8947..7124360dc6408 100644 --- a/onnxruntime/contrib_ops/cpu_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/cpu_contrib_kernels.cc @@ -17,6 +17,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, Range); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, WordConvEmbedding); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, GatherND); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MatMulInteger16); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MurmurHash3); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, MaxpoolWithMask); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, Pad); @@ -87,6 +88,7 @@ void RegisterCpuContribKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/onnxruntime/core/automl/featurizers/src/FeaturizerPrep/Featurizer.h b/onnxruntime/core/automl/featurizers/src/FeaturizerPrep/Featurizer.h new file mode 100644 index 0000000000000..54b737b645da9 --- /dev/null +++ b/onnxruntime/core/automl/featurizers/src/FeaturizerPrep/Featurizer.h @@ -0,0 +1,163 @@ +// ---------------------------------------------------------------------- +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License +// ---------------------------------------------------------------------- +#pragma once + +#include +#include + +namespace Microsoft { +namespace Featurizer { + +///////////////////////////////////////////////////////////////////////// +/// \class Transformer +/// \brief Transforms a single "value" and output the result. +/// A value can be anything from an integer to a collection +/// of integers. 
+/// +template +class Transformer { +public: + // ---------------------------------------------------------------------- + // | Public Types + using return_type = ReturnT; + using arg_type = ArgT; + using transformer_type = Transformer; + + // ---------------------------------------------------------------------- + // | Public Methods + Transformer(void) = default; + virtual ~Transformer(void) = default; + + Transformer(Transformer const &) = delete; + Transformer & operator =(Transformer const &) = delete; + + Transformer(Transformer &&) = default; + Transformer & operator =(Transformer &&) = delete; + + virtual return_type transform(arg_type const &arg) const = 0; + +private: + // ---------------------------------------------------------------------- + // | Private Methods + template + void serialize(ArchiveT &, unsigned int const /*version*/); +}; + +///////////////////////////////////////////////////////////////////////// +/// \class Estimator +/// \brief Collects state over a collection of data, then produces +/// a `Transformer` that is able to operate on that collected +/// state. +/// +template +class Estimator { +public: + // ---------------------------------------------------------------------- + // | Public Types + using transformer_type = Transformer; + using TransformerUniquePtr = std::unique_ptr; + + using estimator_type = Estimator; + + using apache_arrow = unsigned long; // TODO: Temp type as we figure out what will eventually be here + + // ---------------------------------------------------------------------- + // | Public Methods + Estimator(void) = default; + virtual ~Estimator(void) = default; + + Estimator(Estimator const &) = delete; + Estimator & operator =(Estimator const &) = delete; + + Estimator(Estimator &&) = default; + Estimator & operator =(Estimator &&) = delete; + + // This method can be called repeatedly in the support of streaming scenarios + Estimator & fit(apache_arrow const &data); + + // Calls to `commit` are destructive - all previously generated state should + // be reset. `Estimator` objects that want to share state prior to calls to commit + // should implement a `copy` method. 
+ TransformerUniquePtr commit(void); + +private: + // ---------------------------------------------------------------------- + // | Private Data + bool _committed = false; + + // ---------------------------------------------------------------------- + // | Private Methods + template + void serialize(ArchiveT &, unsigned int const /*version*/); + + virtual Estimator & fit_impl(apache_arrow const &data) = 0; + virtual TransformerUniquePtr commit_impl(void) = 0; +}; + +template +typename EstimatorT::TransformerUniquePtr fit_and_commit(typename EstimatorT::apache_arrow const &data, EstimatorConstructorArgsT &&...args); + +// ---------------------------------------------------------------------- +// ---------------------------------------------------------------------- +// ---------------------------------------------------------------------- +// | +// | Implementation +// | +// ---------------------------------------------------------------------- +// ---------------------------------------------------------------------- +// ---------------------------------------------------------------------- + +// ---------------------------------------------------------------------- +// | +// | Transformer +// | +// ---------------------------------------------------------------------- +template +template +void Transformer::serialize(ArchiveT & /*ar*/, unsigned int const /*version*/) { +} + +// ---------------------------------------------------------------------- +// | +// | Estimator +// | +// ---------------------------------------------------------------------- +template +Estimator & Estimator::fit(apache_arrow const &data) { + if(_committed) + throw std::runtime_error("This instance has already been committed"); + + return fit_impl(data); +} + +template +typename Estimator::TransformerUniquePtr Estimator::commit(void) { + if(_committed) + throw std::runtime_error("This instance has already been committed"); + + TransformerUniquePtr result(commit_impl()); + + if(!result) + throw std::runtime_error("Invalid result"); + + _committed = true; + return result; +} + +template +template +void Estimator::serialize(ArchiveT & /*ar*/, unsigned int const /*version*/) { +} + +// ---------------------------------------------------------------------- +// ---------------------------------------------------------------------- +// ---------------------------------------------------------------------- +template +typename EstimatorT::TransformerUniquePtr fit_and_commit(typename EstimatorT::apache_arrow const &data, EstimatorConstructorArgsT &&...args) { + return EstimatorT(std::forward(args)...).fit(data).commit(); +} + +} // namespace Featurizer +} // namespace Microsoft diff --git a/onnxruntime/core/automl/featurizers/src/FeaturizerPrep/Featurizers/DateTimeFeaturizer.cpp b/onnxruntime/core/automl/featurizers/src/FeaturizerPrep/Featurizers/DateTimeFeaturizer.cpp new file mode 100644 index 0000000000000..56fc238d86aee --- /dev/null +++ b/onnxruntime/core/automl/featurizers/src/FeaturizerPrep/Featurizers/DateTimeFeaturizer.cpp @@ -0,0 +1,56 @@ +// ---------------------------------------------------------------------- +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License +// ---------------------------------------------------------------------- +#include "DateTimeFeaturizer.h" + +#ifdef _MSC_VER +inline struct tm *gmtime_r(time_t const* const timer, struct tm* const result) { + return gmtime_s(result, timer) == 0 ? 
result : nullptr; +} + +#endif + +namespace Microsoft { +namespace Featurizer { + +namespace DateTimeFeaturizer { + + TimePoint::TimePoint(const std::chrono::system_clock::time_point& sysTime) { + // Get to a tm to get what we need. + // Eventually C++202x will have expanded chrono support that might + // have what we need, but not yet! + std::tm tmt; + time_t tt = std::chrono::system_clock::to_time_t(sysTime); + std::tm* res = gmtime_r(&tt, &tmt); + if (res) { + year = static_cast(tmt.tm_year) + 1900; + month = static_cast(tmt.tm_mon) + 1; + day = static_cast(tmt.tm_mday); + hour = static_cast(tmt.tm_hour); + minute = static_cast(tmt.tm_min); + second = static_cast(tmt.tm_sec); + dayOfWeek = static_cast(tmt.tm_wday); + dayOfYear = static_cast(tmt.tm_yday); + quarterOfYear = (month + 2) / 3; + weekOfMonth = (day - 1) / 7; + } + else + { + if (tt < 0) { + throw std::invalid_argument("Dates prior to 1970 are not supported."); + } + else { + throw std::invalid_argument("Unknown error converting input date."); + } + } + } + + Transformer::return_type Transformer::transform(arg_type const &arg) const /*override*/ { + return Microsoft::Featurizer::DateTimeFeaturizer::TimePoint(arg); + } + + +} // namespace DateTimeFeaturizer +} // namespace Featurizer +} // namespace Microsoft diff --git a/onnxruntime/core/automl/featurizers/src/FeaturizerPrep/Featurizers/DateTimeFeaturizer.h b/onnxruntime/core/automl/featurizers/src/FeaturizerPrep/Featurizers/DateTimeFeaturizer.h new file mode 100644 index 0000000000000..e1f98351db0b4 --- /dev/null +++ b/onnxruntime/core/automl/featurizers/src/FeaturizerPrep/Featurizers/DateTimeFeaturizer.h @@ -0,0 +1,101 @@ +// ---------------------------------------------------------------------- +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License +// ---------------------------------------------------------------------- +#pragma once + +#include "../Featurizer.h" +#include +#include +#include +#include + +namespace Microsoft { +namespace Featurizer { + +///////////////////////////////////////////////////////////////////////// +/// \namespace DateTimeTransformer +/// \brief A Transformer that takes a chrono::system_clock::time_point and +/// returns a struct with all the data split out. 
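// Hedged worked example (not in the patch) of the derived calendar fields
// computed above, using the 1976-11-17 case that the unit tests added later in
// this change also exercise.
#include <cassert>

int main() {
  const int month = 11, day = 17;
  const int quarterOfYear = (month + 2) / 3;  // November -> 4th quarter
  const int weekOfMonth = (day - 1) / 7;      // the 17th -> zero-based week 2
  assert(quarterOfYear == 4);
  assert(weekOfMonth == 2);
  return 0;
}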
+/// +namespace DateTimeFeaturizer { + + ///////////////////////////////////////////////////////////////////////// + /// \struct TimePoint + /// \brief Struct to hold various components of DateTime information + /// + struct TimePoint { + std::int32_t year = 0; + std::uint8_t month = 0; /* 1-12 */ + std::uint8_t day = 0; /* 1-31 */ + std::uint8_t hour = 0; /* 0-23 */ + std::uint8_t minute = 0; /* 0-59 */ + std::uint8_t second = 0; /* 0-59 */ + std::uint8_t dayOfWeek = 0; /* 0-6 */ + std::uint16_t dayOfYear = 0; /* 0-365 */ + std::uint8_t quarterOfYear = 0; /* 1-4 */ + std::uint8_t weekOfMonth = 0; /* 0-4 */ + + // Need default __ctor to satisfy ORT type system + TimePoint() = default; + TimePoint(const std::chrono::system_clock::time_point& sysTime); + + TimePoint(TimePoint&&) = default; + TimePoint& operator=(TimePoint&&) = default; + + TimePoint(const TimePoint&) = delete; + TimePoint& operator=(const TimePoint&) = delete; + + bool operator==(const TimePoint& o) const { + return year == o.year && + month == o.month && + day == o.day && + hour == o.hour && + minute == o.minute && + second == o.second && + dayOfWeek == o.dayOfWeek && + dayOfYear == o.dayOfYear && + quarterOfYear == o.quarterOfYear && + weekOfMonth == o.weekOfMonth; + } + + enum { + JANUARY = 1, FEBRUARY, MARCH, APRIL, MAY, JUNE, + JULY, AUGUST, SEPTEMBER, OCTOBER, NOVEMBER, DECEMBER + }; + enum { + SUNDAY = 0, MONDAY, TUESDAY, WEDNESDAY, THURSDAY, FRIDAY, SATURDAY + }; + }; + + inline TimePoint SystemToDPTimePoint(const std::chrono::system_clock::time_point& sysTime) { + return TimePoint (sysTime); + } + + ///////////////////////////////////////////////////////////////////////// + /// \class DateTimeTransformer + /// \brief Transformer + /// + class Transformer : public Microsoft::Featurizer::Transformer { + public: + Transformer(void) = default; + ~Transformer(void) override = default; + + Transformer(Transformer const &) = delete; + Transformer & operator =(Transformer const &) = delete; + + Transformer(Transformer &&) = default; + Transformer & operator =(Transformer &&) = delete; + + return_type transform(arg_type const &arg) const override; + + private: + // ---------------------------------------------------------------------- + // | Private Methods + template + void serialize(ArchiveT &ar, unsigned int const version); + }; + +} // Namespace DateTimeFeaturizer +} // Namespace Featurizer +} // Namespace Microsoft diff --git a/onnxruntime/core/automl/featurizers/src/FeaturizerPrep/Featurizers/SampleAdd.cpp b/onnxruntime/core/automl/featurizers/src/FeaturizerPrep/Featurizers/SampleAdd.cpp new file mode 100644 index 0000000000000..b474ce3bd8a62 --- /dev/null +++ b/onnxruntime/core/automl/featurizers/src/FeaturizerPrep/Featurizers/SampleAdd.cpp @@ -0,0 +1,40 @@ +// ---------------------------------------------------------------------- +// Copyright (c) Microsoft Corporation. All rights reserved. 
+// Licensed under the MIT License +// ---------------------------------------------------------------------- +#include "SampleAdd.h" + +namespace Microsoft { +namespace Featurizer { +namespace SampleAdd { + +// ---------------------------------------------------------------------- +// | +// | Transformer +// | +// ---------------------------------------------------------------------- +Transformer::Transformer(std::uint16_t delta) : + _delta(delta) { +} + +Transformer::return_type Transformer::transform(arg_type const &arg) const /*override*/ { + return _delta + arg; +} + +// ---------------------------------------------------------------------- +// | +// | Estimator +// | +// ---------------------------------------------------------------------- +Estimator & Estimator::fit_impl(apache_arrow const &data) /*override*/ { + _accumulated_delta += static_cast(data); + return *this; +} + +Estimator::TransformerUniquePtr Estimator::commit_impl(void) /*override*/ { + return std::make_unique(static_cast(_accumulated_delta)); +} + +} // namespace SampleAdd +} // namespace Featurizer +} // namespace Microsoft diff --git a/onnxruntime/core/automl/featurizers/src/FeaturizerPrep/Featurizers/SampleAdd.h b/onnxruntime/core/automl/featurizers/src/FeaturizerPrep/Featurizers/SampleAdd.h new file mode 100644 index 0000000000000..f4ca7601e5dd0 --- /dev/null +++ b/onnxruntime/core/automl/featurizers/src/FeaturizerPrep/Featurizers/SampleAdd.h @@ -0,0 +1,95 @@ +// ---------------------------------------------------------------------- +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License +// ---------------------------------------------------------------------- +#pragma once + +#include "../Featurizer.h" + +namespace Microsoft { +namespace Featurizer { + +///////////////////////////////////////////////////////////////////////// +/// \namespace SampleAdd +/// \brief A Transformer and Estimator that add values. This is a +/// sample intended to demonstrate patterns within the +/// implementation of these types. +/// +namespace SampleAdd { + +///////////////////////////////////////////////////////////////////////// +/// \class Transformer +/// \brief Transformer that adds an integer value to a saved delta +/// and returns the result. +/// +class Transformer : public Microsoft::Featurizer::Transformer { +public: + // ---------------------------------------------------------------------- + // | Public Methods + Transformer(std::uint16_t delta=0); + ~Transformer(void) override = default; + + Transformer(Transformer const &) = delete; + Transformer & operator =(Transformer const &) = delete; + + Transformer(Transformer &&) = default; + Transformer & operator =(Transformer &&) = delete; + + return_type transform(arg_type const &arg) const override; + +private: + // ---------------------------------------------------------------------- + // | Private Data + std::uint32_t const _delta; + + // ---------------------------------------------------------------------- + // | Private Methods + template + void serialize(ArchiveT &ar, unsigned int const version); +}; + +///////////////////////////////////////////////////////////////////////// +/// \class Estimator +/// \brief Estimator that accumulates a delta value and then +/// creates a Transformer with than value when requested. 
+/// +class Estimator : public Microsoft::Featurizer::Estimator { +public: + // ---------------------------------------------------------------------- + // | Public Methods + Estimator(void) = default; + ~Estimator(void) override = default; + + Estimator(Estimator const &) = delete; + Estimator & operator =(Estimator const &) = delete; + + Estimator(Estimator &&) = default; + Estimator & operator =(Estimator &&) = delete; + +private: + // ---------------------------------------------------------------------- + // | Private Data + std::uint32_t _accumulated_delta = 0; + + // ---------------------------------------------------------------------- + // | Private Methods + template + void serialize(ArchiveT &ar, unsigned int const version); + + Estimator & fit_impl(apache_arrow const &data) override; + TransformerUniquePtr commit_impl(void) override; +}; + +// ---------------------------------------------------------------------- +// ---------------------------------------------------------------------- +// ---------------------------------------------------------------------- +// | +// | Implementation +// | +// ---------------------------------------------------------------------- +// ---------------------------------------------------------------------- +// ---------------------------------------------------------------------- +} // namespace SampleAdd + +} // namespace Featurizer +} // namespace Microsoft diff --git a/onnxruntime/core/automl/featurizers/src/FeaturizerPrep/Featurizers/UnitTests/CMakeLists.txt b/onnxruntime/core/automl/featurizers/src/FeaturizerPrep/Featurizers/UnitTests/CMakeLists.txt new file mode 100644 index 0000000000000..acbc320062979 --- /dev/null +++ b/onnxruntime/core/automl/featurizers/src/FeaturizerPrep/Featurizers/UnitTests/CMakeLists.txt @@ -0,0 +1,48 @@ +# ---------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
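// Hedged sketch (not in the patch): the fit/commit life cycle the Featurizer
// Estimator interface prescribes, using the SampleAdd estimator above as the
// concrete type. The include path is illustrative, and the uint16/uint32
// argument and result types follow the SampleAdd implementation.
#include "SampleAdd.h"
#include <cassert>

int main() {
  Microsoft::Featurizer::SampleAdd::Estimator estimator;
  estimator.fit(10).fit(5);                 // fit() may be called repeatedly (streaming)
  auto transformer = estimator.commit();    // commit() finalizes; later fit() calls throw
  assert(transformer->transform(1) == 16);  // accumulated delta 15 + argument 1
  return 0;
}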
+# Licensed under the MIT License +# ---------------------------------------------------------------------- +cmake_minimum_required(VERSION 3.5.0) + +project(Featurizer_UnitTests LANGUAGES CXX) + +set(CMAKE_MODULE_PATH "$ENV{DEVELOPMENT_ENVIRONMENT_CMAKE_MODULE_PATH}") + +if(NOT WIN32) + string(REPLACE ":" ";" CMAKE_MODULE_PATH "${CMAKE_MODULE_PATH}") + string(REPLACE ":" ";" _includes "$ENV{INCLUDE}") + string(REPLACE ":" ";" _libs "$ENV{LIB}") +endif() + +set(CppCommon_STATIC_CRT ON CACHE BOOL "" FORCE) +set(BoostCommon_HEADER_ONLY ON CACHE BOOL "" FORCE) + +include(CppCommon) +include(BoostCommon) + +set(CMAKE_CXX_STANDARD 14) +set(CMAKE_CXX_STANDARD_REQUIRED ON) +set(CMAKE_CXX_EXTENSIONS OFF) + +add_library(libFeaturizers STATIC + ../SampleAdd.h + ../SampleAdd.cpp + ../DateTimeFeaturizer.h + ../DateTimeFeaturizer.cpp +) + +enable_testing() + +foreach(_test_name IN ITEMS + SampleAdd_UnitTest + DateTimeFeaturizer_UnitTests +) + add_executable(${_test_name} ${_test_name}.cpp) + + target_include_directories(${_test_name} PRIVATE ${_includes}) + target_link_directories(${_test_name} PRIVATE ${_libs}) + + target_link_libraries(${_test_name} PRIVATE ${Boost_LIBRARIES} libFeaturizers) + + add_test(NAME ${_test_name} COMMAND ${_test_name} --success) +endforeach() diff --git a/onnxruntime/core/automl/featurizers/src/FeaturizerPrep/Featurizers/UnitTests/DateTimeFeaturizer_UnitTests.cpp b/onnxruntime/core/automl/featurizers/src/FeaturizerPrep/Featurizers/UnitTests/DateTimeFeaturizer_UnitTests.cpp new file mode 100644 index 0000000000000..d81bb22964dbe --- /dev/null +++ b/onnxruntime/core/automl/featurizers/src/FeaturizerPrep/Featurizers/UnitTests/DateTimeFeaturizer_UnitTests.cpp @@ -0,0 +1,125 @@ +// ---------------------------------------------------------------------- +// Copyright (c) Microsoft Corporation. All rights reserved. 
+// Licensed under the MIT License +// ---------------------------------------------------------------------- + +#define CATCH_CONFIG_MAIN +#include +#include "gtest/gtest.h" + +#include "../DateTimeFeaturizer.h" + + +namespace Microsoft { +namespace Featurizer { +namespace DateTimeFeaturizer { + +using SysClock = std::chrono::system_clock; + +TEST(DateTimeFeaturizer_DateTime, Past_1976_Nov_17__12_27_04) { + const time_t date = 217081624; + SysClock::time_point stp = SysClock::from_time_t(date); + + // Constructor + TimePoint tp(stp); + ASSERT_EQ(tp.year, 1976); + ASSERT_EQ(tp.month, TimePoint::NOVEMBER); + ASSERT_EQ(tp.day, 17); + ASSERT_EQ(tp.hour, 12); + ASSERT_EQ(tp.minute, 27); + ASSERT_EQ(tp.second, 4); + ASSERT_EQ(tp.dayOfWeek, TimePoint::WEDNESDAY); + ASSERT_EQ(tp.dayOfYear, 321); + ASSERT_EQ(tp.quarterOfYear, 4); + ASSERT_EQ(tp.weekOfMonth, 2); + + // assignment + TimePoint tp1 = stp; + ASSERT_EQ(tp1.year, 1976); + ASSERT_EQ(tp1.month, TimePoint::NOVEMBER); + ASSERT_EQ(tp1.day, 17); + + // function + TimePoint tp2 = SystemToDPTimePoint(stp); + ASSERT_EQ(tp2.year, 1976); + ASSERT_EQ(tp2.month, TimePoint::NOVEMBER); + ASSERT_EQ(tp2.day, 17); +} + +TEST(DateTimeFeaturizer_Transformer , Past_1976_Nov_17__12_27_05) { + const time_t date = 217081625; + SysClock::time_point stp = SysClock::from_time_t(date); + + Transformer dt; + TimePoint tp = dt.transform(stp); + ASSERT_EQ(tp.year, 1976); + ASSERT_EQ(tp.month, TimePoint::NOVEMBER); + ASSERT_EQ(tp.day, 17); + ASSERT_EQ(tp.hour, 12); + ASSERT_EQ(tp.minute, 27); + ASSERT_EQ(tp.second, 5); + ASSERT_EQ(tp.dayOfWeek, TimePoint::WEDNESDAY); + ASSERT_EQ(tp.dayOfYear, 321); + ASSERT_EQ(tp.quarterOfYear, 4); + ASSERT_EQ(tp.weekOfMonth, 2); + +} + +TEST(DateTimeFeaturizer_Transformer , Future_2025_June_30) { + const time_t date = 1751241600; + SysClock::time_point stp = SysClock::from_time_t(date); + + Transformer dt; + TimePoint tp = dt.transform(stp); + ASSERT_EQ(tp.year, 2025); + ASSERT_EQ(tp.month, TimePoint::JUNE); + ASSERT_EQ(tp.day, 30); + ASSERT_EQ(tp.hour, 0); + ASSERT_EQ(tp.minute, 0); + ASSERT_EQ(tp.second, 0); + ASSERT_EQ(tp.dayOfWeek, TimePoint::MONDAY); + ASSERT_EQ(tp.dayOfYear, 180); + ASSERT_EQ(tp.quarterOfYear, 2); + ASSERT_EQ(tp.weekOfMonth, 4); +} + +#ifdef _MSC_VER +// others define system_clock::time_point as nanoseconds (64-bit), +// which rolls over somewhere around 2260. Still a couple hundred years! 
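A quick back-of-the-envelope check of the rollover estimate in the comment above, written as a standalone constant expression:

// A signed 64-bit nanosecond count covers 2^63 - 1 ns, which is roughly
// 9.22e18 / (1e9 * 60 * 60 * 24 * 365.25) ~= 292 years, so a 1970 epoch
// runs out a little after the year 2260, consistent with the comment.
constexpr double kSecondsPerYear = 60.0 * 60.0 * 24.0 * 365.25;
constexpr double kYearsOfInt64Nanoseconds = 9223372036854775807.0 / (1e9 * kSecondsPerYear);
static_assert(kYearsOfInt64Nanoseconds > 292.0 && kYearsOfInt64Nanoseconds < 293.0,
              "int64 nanoseconds span roughly 292 years");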
+TEST(DateTimeFeaturizer_Transformer , Far_Future__2998_March_2__14_03_02) { + const time_t date = 32445842582; + SysClock::time_point stp = SysClock::from_time_t(date); + + Transformer dt; + TimePoint tp = dt.transform(stp); + ASSERT_EQ(tp.year, 2998); + ASSERT_EQ(tp.month, TimePoint::MARCH); + ASSERT_EQ(tp.day, 2); + ASSERT_EQ(tp.hour, 14); + ASSERT_EQ(tp.minute, 3); + ASSERT_EQ(tp.second, 2); + ASSERT_EQ(tp.dayOfWeek, TimePoint::FRIDAY); + ASSERT_EQ(tp.dayOfYear, 60); + ASSERT_EQ(tp.quarterOfYear, 1); + ASSERT_EQ(tp.weekOfMonth, 0); +} + +#else + +// msvcrt doesn't support negative time_t, so nothing before 1970 +TEST(DateTimeFeaturizer_Transformer, Pre_Epoch__1776_July_4) { + + const time_t date = -6106060800; + SysClock::time_point stp = SysClock::from_time_t(date); + + // Constructor + Transformer dt; + TimePoint tp = dt.transform(stp); + ASSERT_EQ(tp.year, 1776); + ASSERT_EQ(tp.month, TimePoint::JULY); + ASSERT_EQ(tp.day, 4); +} +#endif /* _MSC_VER */ +} // namespace DateTimeFeaturizer +} // namespace Featurizer +} // namespace Microsoft diff --git a/onnxruntime/core/automl/featurizers/src/FeaturizerPrep/Featurizers/UnitTests/SampleAdd_UnitTest.cpp b/onnxruntime/core/automl/featurizers/src/FeaturizerPrep/Featurizers/UnitTests/SampleAdd_UnitTest.cpp new file mode 100644 index 0000000000000..b3796ec3c4d62 --- /dev/null +++ b/onnxruntime/core/automl/featurizers/src/FeaturizerPrep/Featurizers/UnitTests/SampleAdd_UnitTest.cpp @@ -0,0 +1,22 @@ +// ---------------------------------------------------------------------- +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License +// ---------------------------------------------------------------------- + +#define CATCH_CONFIG_MAIN +#include "gtest/gtest.h" + +#include "../SampleAdd.h" + +TEST(SampleAddTests, Transformer) { + ASSERT_EQ(Microsoft::Featurizer::SampleAdd::Transformer(10).transform(20), 30U); + ASSERT_EQ(Microsoft::Featurizer::SampleAdd::Transformer(20).transform(1), 21U); +} + +TEST(SampleAddTests, Estimator) { + ASSERT_EQ(Microsoft::Featurizer::SampleAdd::Estimator().fit(10).commit()->transform(20), 30U); + ASSERT_EQ(Microsoft::Featurizer::SampleAdd::Estimator().fit(20).commit()->transform(1), 21U); + + ASSERT_EQ(Microsoft::Featurizer::SampleAdd::Estimator().fit(10).fit(20).commit()->transform(20), 50U); + ASSERT_EQ(Microsoft::Featurizer::SampleAdd::Estimator().fit(10).fit(20).fit(30).commit()->transform(20), 80U); +} diff --git a/onnxruntime/core/automl/featurizers/src/FeaturizerPrep/Featurizers/UnitTests/code_coverage.yaml b/onnxruntime/core/automl/featurizers/src/FeaturizerPrep/Featurizers/UnitTests/code_coverage.yaml new file mode 100644 index 0000000000000..e3f068978a9bd --- /dev/null +++ b/onnxruntime/core/automl/featurizers/src/FeaturizerPrep/Featurizers/UnitTests/code_coverage.yaml @@ -0,0 +1,5 @@ +filter: + includes: + - Microsoft::Featurizer::* + excludes: + - std::* diff --git a/onnxruntime/core/automl/featurizers/src/FeaturizerPrep/Traits.h b/onnxruntime/core/automl/featurizers/src/FeaturizerPrep/Traits.h new file mode 100644 index 0000000000000..37a70a059d14a --- /dev/null +++ b/onnxruntime/core/automl/featurizers/src/FeaturizerPrep/Traits.h @@ -0,0 +1,218 @@ +// ---------------------------------------------------------------------- +// Copyright (c) Microsoft Corporation. All rights reserved. 
+// Licensed under the MIT License
+// ----------------------------------------------------------------------
+
+#pragma once
+#include 
+#include 
+#include 
+#include 
+#include 
+
+namespace Microsoft {
+namespace Featurizer {
+namespace Traits {
+
+// XXX: Define the type
+template <typename T>
+struct Nullable {};
+
+/////////////////////////////////////////////////////////////////////////
+///  \namespace     Traits
+///  \brief         We have a range of types we are dealing with. Many types
+///                 have different ways to represent what a `NULL` value is
+///                 (float has NAN for example) as well as different ways to
+///                 convert the value to a string representation. By using
+///                 templates combined with partial template specialization
+///                 we can handle scenarios like these that vary based on the data type.
+///
+///                 Example: This allows us to do things like `Traits::IsNull()`
+///                 and `Traits::IsNull()` and let the trait itself deal with the
+///                 actual implementation and allows us as developers to not worry about that.
+///
+///                 This benefit is magnified because we are also using templates for our
+///                 transformers. When we declare that a transformer has type T = std::int8_t,
+///                 we can then also use `Traits::IsNull()` and the compiler will know that
+///                 `T` is a `std::int8_t` and call the appropriate template specialization.
+///
+template <typename T>
+struct Traits {};
+
+/////////////////////////////////////////////////////////////////////////
+///  \namespace     Traits
+///  \brief         When using partial template specialization, if the compiler
+///                 cannot find a more specific implementation of the template
+///                 it will fall back to the base template and use whatever is
+///                 defined there. If you have methods defined in that base template,
+///                 it makes it very difficult to debug what is going on. By
+///                 putting no implementation in the `Traits<>` template and
+///                 having the real base struct be `TraitsImpl<>`, if you try and
+///                 specify a trait that doesn't have a specialization, the compiler
+///                 can detect that and throw an error during compilation.
+///
+///                 Example: There is no template `Traits`. If you try and use it
+///                 the compiler will fall back to the `Traits<>` struct which has no methods
+///                 defined. Trying to then use `Traits` will cause a compile-time error
+///                 letting you know something isn't correct.
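A standalone sketch of the pattern described above, using illustrative names rather than the Featurizer ones:

#include <cmath>

// Empty primary template: a type without a specialization provides no
// members, so any attempt to use it fails at compile time.
template <typename T>
struct ExampleTraits {};

// Specialization: for float, "null" is represented by NaN.
template <>
struct ExampleTraits<float> {
    static bool IsNull(float value) { return std::isnan(value); }
};

int main() {
    bool is_null = ExampleTraits<float>::IsNull(0.0f);   // resolves to the float specialization
    // ExampleTraits<char>::IsNull('a');                 // would not compile: no IsNull in the primary template
    return is_null ? 1 : 0;
}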
+/// +template +struct TraitsImpl { + using nullable_type = Nullable; + static bool IsNull(nullable_type const& value) { + return !value.is_initialized(); + } +}; + +template <> +struct Traits : public TraitsImpl { + using nullable_type = float; + static bool IsNull(nullable_type const& value) { + return std::isnan(value); + } + + // static std::string ToString(nullable_type const& value) { + // return std::to_string(value); + // } +}; + +template <> +struct Traits : public TraitsImpl { + using nullable_type = double; + static bool IsNull(nullable_type const& value) { + return std::isnan(value); + } + + // static std::string ToString(nullable_type const& value) { + // return std::to_string(value); + // } +}; + +template <> +struct Traits : public TraitsImpl { + // static std::string ToString(std::int8_t const& value) { + // return std::to_string(value); + // } +}; + +template <> +struct Traits : public TraitsImpl { + // static std::string ToString(std::int16_t const& value) { + // return std::to_string(value); + // } +}; + +template <> +struct Traits : public TraitsImpl { + // static std::string ToString(std::int32_t const& value) { + // return std::to_string(value); + // } +}; + +template <> +struct Traits : public TraitsImpl { + // static std::string ToString(std::int64_t const& value) { + // return std::to_string(value); + // } +}; + +template <> +struct Traits : public TraitsImpl { + // static std::string ToString(std::uint8_t const& value) { + // return std::to_string(value); + // } +}; + +template <> +struct Traits : public TraitsImpl { + using nullable_type = Nullable; + // static std::string ToString(std::uint16_t const& value) { + // return std::to_string(value); + // } +}; + +template <> +struct Traits : public TraitsImpl { + // static std::string ToString(std::uint32_t const& value) { + // return std::to_string(value); + // } +}; + +template <> +struct Traits : public TraitsImpl { + // static std::string ToString(std::uint64_t const& value) { + // return std::to_string(value); + // } +}; + +template <> +struct Traits : public TraitsImpl { + // static std::string ToString(std::string const& value) { + // value; + // } +}; + +template +struct Traits> : public TraitsImpl> { + // static std::string ToString(std::array const& value) { + // // Decide what to return here + // throw std::logic_error("Function not yet implemented"); + // } +}; + +template <> +struct Traits : public TraitsImpl { + // static std::string ToString(bool const& value) { + // // Decide what to return here + // throw std::logic_error("Function not yet implemented"); + // } +}; + +template +struct Traits> : public TraitsImpl> { + // static std::string ToString(std::map const& value) { + // // Decide what to return here + // throw std::logic_error("Function not yet implemented"); + // } +}; + +template +struct Traits> : public TraitsImpl> { + // static std::string ToString(std::vector const& value) { + // // Decide what to return here + // throw std::logic_error("Function not yet implemented"); + // } +}; + +template +struct Traits> : public TraitsImpl> { + // static std::string ToString(std::function const& value) { + // // Decide what to return here + // throw std::logic_error("Function not yet implemented"); + // } +}; + +template +struct Traits> : public TraitsImpl> { + using nullable_type = Nullable; + + // static std::string ToString(nullable_type const& value) { + // if (value) { + // return Traits::ToString(value.get()); + // } + + // return "NULL"; + // } +}; + +template +struct Traits> : public 
TraitsImpl> { + // static std::string ToString(std::tuple const& value) { + // // Decide what to return here + // throw std::logic_error("Function not yet implemented"); + // } +}; + +} // namespace Traits +} // namespace Featurizer +} // namespace Microsoft diff --git a/onnxruntime/core/automl/featurizers/src/FeaturizerPrep/UnitTests/CMakeLists.txt b/onnxruntime/core/automl/featurizers/src/FeaturizerPrep/UnitTests/CMakeLists.txt new file mode 100644 index 0000000000000..024c76f3443a7 --- /dev/null +++ b/onnxruntime/core/automl/featurizers/src/FeaturizerPrep/UnitTests/CMakeLists.txt @@ -0,0 +1,41 @@ +# ---------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License +# ---------------------------------------------------------------------- +cmake_minimum_required(VERSION 3.5.0) + +project(Featurizer_UnitTests LANGUAGES CXX) + +set(CMAKE_MODULE_PATH "$ENV{DEVELOPMENT_ENVIRONMENT_CMAKE_MODULE_PATH}") + +if(NOT WIN32) + string(REPLACE ":" ";" CMAKE_MODULE_PATH "${CMAKE_MODULE_PATH}") + string(REPLACE ":" ";" _includes "$ENV{INCLUDE}") + string(REPLACE ":" ";" _libs "$ENV{LIB}") +endif() + +set(CppCommon_STATIC_CRT ON CACHE BOOL "" FORCE) +set(BoostCommon_HEADER_ONLY ON CACHE BOOL "" FORCE) + +include(CppCommon) +include(BoostCommon) + +set(CMAKE_CXX_STANDARD 14) +set(CMAKE_CXX_STANDARD_REQUIRED ON) +set(CMAKE_CXX_EXTENSIONS OFF) + +enable_testing() + +foreach(_test_name IN ITEMS + Featurizer_UnitTest + Traits_UnitTests +) + add_executable(${_test_name} ${_test_name}.cpp) + + target_include_directories(${_test_name} PRIVATE ${_includes}) + target_link_directories(${_test_name} PRIVATE ${_libs}) + + target_link_libraries(${_test_name} PRIVATE ${Boost_LIBRARIES}) + + add_test(NAME ${_test_name} COMMAND ${_test_name} --success) +endforeach() diff --git a/onnxruntime/core/automl/featurizers/src/FeaturizerPrep/UnitTests/Featurizer_UnitTest.cpp b/onnxruntime/core/automl/featurizers/src/FeaturizerPrep/UnitTests/Featurizer_UnitTest.cpp new file mode 100644 index 0000000000000..c0340e738c1c4 --- /dev/null +++ b/onnxruntime/core/automl/featurizers/src/FeaturizerPrep/UnitTests/Featurizer_UnitTest.cpp @@ -0,0 +1,104 @@ +// ---------------------------------------------------------------------- +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License +// ---------------------------------------------------------------------- + +#define CATCH_CONFIG_MAIN +#include "gtest/gtest.h" +#include "../Featurizer.h" + +class MyTransformer : public Microsoft::Featurizer::Transformer { +public: + // ---------------------------------------------------------------------- + // | Public Methods + MyTransformer(bool true_on_odd=false) : + _true_on_odd(true_on_odd) { + } + + ~MyTransformer(void) override = default; + + MyTransformer(MyTransformer const &) = delete; + MyTransformer & operator =(MyTransformer const &) = delete; + + MyTransformer(MyTransformer &&) = default; + MyTransformer & operator =(MyTransformer &&) = delete; + + return_type transform(arg_type const &arg) const override { + bool const is_odd(arg & 1); + + return _true_on_odd ? 
is_odd : !is_odd; + } + +private: + // ---------------------------------------------------------------------- + // | Private Data + bool const _true_on_odd; +}; + +class MyEstimator : public Microsoft::Featurizer::Estimator { +public: + // ---------------------------------------------------------------------- + // | Public Methods + MyEstimator(bool return_invalid_transformer=false) : + _return_invalid_transformer(return_invalid_transformer) { + } + + ~MyEstimator(void) override = default; + + MyEstimator(MyEstimator const &) = delete; + MyEstimator & operator =(MyEstimator const &) = delete; + + MyEstimator(MyEstimator &&) = default; + MyEstimator & operator =(MyEstimator &&) = delete; + +private: + // ---------------------------------------------------------------------- + // | Private Data + bool const _return_invalid_transformer; + bool _true_on_odd_state; + + // ---------------------------------------------------------------------- + // | Private Methods + MyEstimator & fit_impl(apache_arrow const &data) override { + _true_on_odd_state = static_cast(data); + return *this; + } + + TransformerUniquePtr commit_impl(void) override { + if(_return_invalid_transformer) + return TransformerUniquePtr(); + + return std::make_unique(_true_on_odd_state); + } +}; + +TEST(FeaturizerTests, TransformerFunctionality) { + ASSERT_TRUE(MyTransformer(true).transform(1)); + ASSERT_FALSE(MyTransformer(false).transform(1)); + ASSERT_FALSE(MyTransformer(true).transform(2)); + ASSERT_TRUE(MyTransformer(false).transform(2)); +} + +TEST(FeaturizerTests, EstimatorFunctionality) { + ASSERT_TRUE(MyEstimator().fit(1).commit()->transform(1)); + ASSERT_FALSE(MyEstimator().fit(0).commit()->transform(1)); + ASSERT_FALSE(MyEstimator().fit(1).commit()->transform(2)); + ASSERT_TRUE(MyEstimator().fit(0).commit()->transform(2)); +} + +TEST(FeaturizerTests, EstimatorErrors) { + MyEstimator e; + + ASSERT_NE(e.commit(), nullptr); + //CHECK_THROWS_WITH(e.fit(1), Catch::Contains("has already been committed")); + //CHECK_THROWS_WITH(e.commit(), Catch::Contains("has already been committed")); + + //CHECK_THROWS_WITH(MyEstimator(true).commit(), Catch::Matches("Invalid result")); +} + +TEST(FeaturizerTests, EstimatorFitAndCommit) { + ASSERT_TRUE(Microsoft::Featurizer::fit_and_commit(1, false)->transform(1)); + ASSERT_FALSE(Microsoft::Featurizer::fit_and_commit(0, false)->transform(1)); + ASSERT_FALSE(Microsoft::Featurizer::fit_and_commit(1, false)->transform(2)); + ASSERT_TRUE(Microsoft::Featurizer::fit_and_commit(0, false)->transform(2)); +} diff --git a/onnxruntime/core/automl/featurizers/src/FeaturizerPrep/UnitTests/Traits_UnitTests.cpp b/onnxruntime/core/automl/featurizers/src/FeaturizerPrep/UnitTests/Traits_UnitTests.cpp new file mode 100644 index 0000000000000..66589a5c9decc --- /dev/null +++ b/onnxruntime/core/automl/featurizers/src/FeaturizerPrep/UnitTests/Traits_UnitTests.cpp @@ -0,0 +1,40 @@ +// ---------------------------------------------------------------------- +// Copyright (c) Microsoft Corporation. All rights reserved. 
+// Licensed under the MIT License +// ---------------------------------------------------------------------- +#define CATCH_CONFIG_MAIN +#include +#include "gtest/gtest.h" + +#include "../Traits.h" + +using namespace Microsoft::Featurizer::Traits; + +// Floating point values +static_assert(std::is_same::nullable_type, float>::value, "Incorrect nullable type for float"); +static_assert(std::is_same::nullable_type, double>::value, "Incorrect nullable type for double"); + +// Int values +static_assert(std::is_same::nullable_type, Nullable>::value, "Incorrect nullable type for std::int8_t"); +static_assert(std::is_same::nullable_type, Nullable>::value, "Incorrect nullable type for std::int16_t"); +static_assert(std::is_same::nullable_type, Nullable>::value, "Incorrect nullable type for std::int32_t"); +static_assert(std::is_same::nullable_type, Nullable>::value, "Incorrect nullable type for std::int64_t"); +static_assert(std::is_same::nullable_type, Nullable>::value, "Incorrect nullable type for std::uint8_t"); +static_assert(std::is_same::nullable_type, Nullable>::value, "Incorrect nullable type for std::uint16_t"); +static_assert(std::is_same::nullable_type, Nullable>::value, "Incorrect nullable type for std::uint32_t"); +static_assert(std::is_same::nullable_type, Nullable>::value, "Incorrect nullable type for std::uint64_t"); + +// Others +static_assert(std::is_same::nullable_type, Nullable>::value, "Incorrect nullable type for std::string"); +static_assert(std::is_same>::nullable_type, Nullable>>::value, "Incorrect nullable type for std::array"); +static_assert(std::is_same::nullable_type, Nullable>::value, "Incorrect nullable type for std::string"); +static_assert(std::is_same>::nullable_type, Nullable>>::value, "Incorrect nullable type for std::string"); +static_assert(std::is_same>::nullable_type, Nullable>>::value, "Incorrect nullable type for std::string"); +static_assert(std::is_same>::nullable_type, Nullable>>::value, "Incorrect nullable type for std::string"); +static_assert(std::is_same>::nullable_type, Nullable>::value, "Incorrect nullable type for std::string"); +static_assert(std::is_same>::nullable_type, Nullable>>::value, "Incorrect nullable type for std::string"); + +// Dummy test so it will compile. Replace this with actual tests. +TEST(TraitsTests, Dummy) { + ASSERT_TRUE(true); +} diff --git a/onnxruntime/core/automl/featurizers/src/FeaturizerPrep/UnitTests/test_main.cpp b/onnxruntime/core/automl/featurizers/src/FeaturizerPrep/UnitTests/test_main.cpp new file mode 100644 index 0000000000000..b6a004002b83c --- /dev/null +++ b/onnxruntime/core/automl/featurizers/src/FeaturizerPrep/UnitTests/test_main.cpp @@ -0,0 +1,18 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#include "gtest/gtest.h" + +GTEST_API_ int main(int argc, char** argv) { + int status = 0; + + testing::InitGoogleTest(&argc, argv); + try { + status = RUN_ALL_TESTS(); + } catch (const std::exception& ex) { + std::cerr << ex.what(); + status = -1; + } + + return status; +} diff --git a/onnxruntime/core/codegen/common/common.cc b/onnxruntime/core/codegen/common/common.cc index 757c1677dd2e5..f7a774609669c 100644 --- a/onnxruntime/core/codegen/common/common.cc +++ b/onnxruntime/core/codegen/common/common.cc @@ -120,6 +120,11 @@ std::unique_ptr ToCapacity(const onnxruntime::GraphViewer& gr meta_def->name += "_With" + std::to_string(subgraph->nodes.size()) + "Nodes_"; meta_def->name += end_node.OpType() + std::to_string(end_node_index); + std::unordered_set real_output_names; + for (const auto* def : graph.GetOutputs()) { + real_output_names.insert(def->Name()); + } + for (const auto& node_index : subgraph->nodes) { const auto& node = *graph.GetNode(node_index); // handle current graph's inputs @@ -140,6 +145,7 @@ std::unique_ptr ToCapacity(const onnxruntime::GraphViewer& gr // 1. Output NodeArg is not used by any Node // 2. Output NodeArg is used by at least one Node out of this subgraph. // Note a NodeArg can be used by Nodes in and out of the subgraph at the same time. + // 3. Output NodeArg is one of real outputs of an Ort graph. auto InsertOutputToSubgraph = [&meta_def](const NodeArg* def) { if (std::find(meta_def->outputs.begin(), meta_def->outputs.end(), def->Name()) == @@ -169,11 +175,12 @@ std::unique_ptr ToCapacity(const onnxruntime::GraphViewer& gr } } - // handle case 1 + // handle case 1 and 3 node.ForEachWithIndex( node.OutputDefs(), [&](const onnxruntime::NodeArg& def, size_t) { - if (input_names_from_the_output_node.count(def.Name()) == 0) { + if (input_names_from_the_output_node.count(def.Name()) == 0 || + real_output_names.count(def.Name()) > 0) { InsertOutputToSubgraph(&def); } return Status::OK(); diff --git a/onnxruntime/core/codegen/common/creator.h b/onnxruntime/core/codegen/common/creator.h index d15e86b5a481f..b31a12db4875b 100644 --- a/onnxruntime/core/codegen/common/creator.h +++ b/onnxruntime/core/codegen/common/creator.h @@ -25,7 +25,7 @@ class CreatorBase { CreatorBase(const std::string& name) : name_(name) {} - ~CreatorBase() = default; + virtual ~CreatorBase() = default; virtual RETURN_TYPE Evaluate(INPUT_TYPE, NODE_TYPE, diff --git a/onnxruntime/core/codegen/common/dispatcher.h b/onnxruntime/core/codegen/common/dispatcher.h index b4313cecad3a8..80a854a06977c 100644 --- a/onnxruntime/core/codegen/common/dispatcher.h +++ b/onnxruntime/core/codegen/common/dispatcher.h @@ -16,6 +16,7 @@ namespace codegen { // 2) dump corresponding name // DispatcherBase may or may not keep ownership, // depending on the template parameter, CONTENT_TYPE. 
+// Note DispatcherBase has a protected destructor template class DispatcherBase { @@ -68,6 +69,7 @@ class DispatcherBase { std::string name_; std::unordered_map contents_; ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(DispatcherBase); + ~DispatcherBase() = default; }; } // namespace codegen diff --git a/onnxruntime/core/codegen/common/profile.h b/onnxruntime/core/codegen/common/profile.h index 642ae83db723b..d9e5a9725e9e4 100644 --- a/onnxruntime/core/codegen/common/profile.h +++ b/onnxruntime/core/codegen/common/profile.h @@ -28,7 +28,7 @@ class ProfilerEvent { } // namespace onnxruntime -#define CODEGEN_PROFILER_EVENT(name) onnxruntime::ProfilerEvent name##_profiler_event(#name) +#define CODEGEN_PROFILER_EVENT(name) onnxruntime::ProfilerEvent profiler_event(name) #else diff --git a/onnxruntime/core/codegen/common/registry.h b/onnxruntime/core/codegen/common/registry.h index 1ec06d4d8d96c..c1642e76e2120 100644 --- a/onnxruntime/core/codegen/common/registry.h +++ b/onnxruntime/core/codegen/common/registry.h @@ -21,6 +21,8 @@ class RegistryBase { public: RegistryBase() = default; + virtual ~RegistryBase() = default; + bool Contains(const std::string& name) const { return contents_.count(name) > 0; } diff --git a/onnxruntime/core/codegen/common/settings.cc b/onnxruntime/core/codegen/common/settings.cc index c046f2892088d..529cb654f922c 100644 --- a/onnxruntime/core/codegen/common/settings.cc +++ b/onnxruntime/core/codegen/common/settings.cc @@ -70,5 +70,9 @@ bool CodeGenSettings::OptionMatches(const std::string& key, const std::string& v #endif } +void CodeGenSettings::Clear() { + options_.clear(); +} + } // namespace codegen } // namespace onnxruntime diff --git a/onnxruntime/core/codegen/common/settings.h b/onnxruntime/core/codegen/common/settings.h index 95a2282ccb1ff..4bce9a614b7e1 100644 --- a/onnxruntime/core/codegen/common/settings.h +++ b/onnxruntime/core/codegen/common/settings.h @@ -26,6 +26,7 @@ class CodeGenSettings { std::string GetOptionValue(const std::string& key) const; bool HasOption(const std::string& key) const; bool OptionMatches(const std::string& key, const std::string& value) const; + void Clear(); static CodeGenSettings& Instance(); private: diff --git a/onnxruntime/core/codegen/mti/math/gemm.cc b/onnxruntime/core/codegen/mti/math/gemm.cc index b5e5da5301775..7a79513ccaa97 100644 --- a/onnxruntime/core/codegen/mti/math/gemm.cc +++ b/onnxruntime/core/codegen/mti/math/gemm.cc @@ -17,10 +17,12 @@ tvm::Tensor Gemm(const tvm::Tensor& A, const tvm::Tensor& B, const tvm::Tensor& bool trans_A, bool trans_B, float alpha, float beta, const std::string& name) { auto A_dot_B = MatMul2D(A, B, trans_A, trans_B, name + "_matmul2d"); + tvm::Expr alphaExpr = tvm::make_const(A->dtype, alpha); if (beta != 0) { - return Rename(alpha * A_dot_B + (beta * C), name); + tvm::Expr betaExpr = tvm::make_const(A->dtype, beta); + return Rename(alphaExpr * A_dot_B + (betaExpr * C), name); } else { - return Rename(alpha * A_dot_B, name); + return Rename(alphaExpr * A_dot_B, name); } } diff --git a/onnxruntime/core/codegen/mti/math/matmul_ops.cc b/onnxruntime/core/codegen/mti/math/matmul_ops.cc index 672aa3a6cf8db..46f2fb75b6e24 100644 --- a/onnxruntime/core/codegen/mti/math/matmul_ops.cc +++ b/onnxruntime/core/codegen/mti/math/matmul_ops.cc @@ -117,22 +117,31 @@ tvm::Tensor MatMul(const tvm::Tensor& A, const tvm::Tensor& B, const std::string return tvm::sum(A(a_indices) * B(b_indices), {k}); }; - tvm::Array output_shape; - int64_t output_rank = std::max(a_rank, b_rank); - 
MTI_ASSERT(tvm::ir::Equal(A_shape[a_rank - 1], B_shape[b_rank - 2])); - for (int64_t i = 0; i < output_rank - 2; i++) { - tvm::Expr broadcasted_dim = tvm::make_const(HalideIR::Int(32), 1); - bool broadcasted = - BroadcastDim(A_shape, i, output_rank, broadcasted_dim) && - BroadcastDim(B_shape, i, output_rank, broadcasted_dim); - MTI_ASSERT(broadcasted); - output_shape.push_back(broadcasted_dim); - } - output_shape.push_back(A_shape[a_rank - 2]); - output_shape.push_back(B_shape[b_rank - 1]); - return tvm::compute(output_shape, l, name); + return tvm::compute(ComputeMatMulShape(A_shape, B_shape), l, name); } } +tvm::Array +ComputeMatMulShape( + const tvm::Array& A_shape, + const tvm::Array& B_shape) { + auto a_rank = A_shape.size(); + auto b_rank = B_shape.size(); + tvm::Array output_shape; + int64_t output_rank = std::max(a_rank, b_rank); + MTI_ASSERT(tvm::ir::Equal(A_shape[a_rank - 1], B_shape[b_rank - 2])); + for (int64_t i = 0; i < output_rank - 2; i++) { + tvm::Expr broadcasted_dim = tvm::make_const(HalideIR::Int(32), 1); + bool broadcasted = + BroadcastDim(A_shape, i, output_rank, broadcasted_dim) && + BroadcastDim(B_shape, i, output_rank, broadcasted_dim); + MTI_ASSERT(broadcasted); + output_shape.push_back(broadcasted_dim); + } + output_shape.push_back(A_shape[a_rank - 2]); + output_shape.push_back(B_shape[b_rank - 1]); + return output_shape; +} + } // namespace tvm_codegen } // namespace onnxruntime diff --git a/onnxruntime/core/codegen/mti/math/matmul_ops.h b/onnxruntime/core/codegen/mti/math/matmul_ops.h index c149486a87fab..7180b4f6d81e5 100644 --- a/onnxruntime/core/codegen/mti/math/matmul_ops.h +++ b/onnxruntime/core/codegen/mti/math/matmul_ops.h @@ -8,6 +8,11 @@ namespace onnxruntime { namespace tvm_codegen { +tvm::Array +ComputeMatMulShape( + const tvm::Array& A_shape, + const tvm::Array& B_shape); + tvm::Tensor MatMul2D(const tvm::Tensor& A, const tvm::Tensor& B, bool trans_a = false, bool trans_b = false, const std::string& name = "matmul2d"); tvm::Tensor MatMul(const tvm::Tensor& A, const tvm::Tensor& B, const std::string& name = "matmul"); diff --git a/onnxruntime/core/codegen/mti/math/unary_ops.cc b/onnxruntime/core/codegen/mti/math/unary_ops.cc index 7f45a9115fb0b..a9b18072988b2 100644 --- a/onnxruntime/core/codegen/mti/math/unary_ops.cc +++ b/onnxruntime/core/codegen/mti/math/unary_ops.cc @@ -21,7 +21,9 @@ tvm::Tensor Abs(const tvm::Tensor& X, const std::string& name) { } tvm::Tensor Affine(const tvm::Tensor& X, float alpha, float beta, const std::string& name) { - return Rename(alpha * X + beta, name); + tvm::Expr alphaExpr = tvm::make_const(X->dtype, alpha); + tvm::Expr betaExpr = tvm::make_const(X->dtype, beta); + return Rename(alphaExpr * X + betaExpr, name); } tvm::Tensor Ceil(const tvm::Tensor& X, const std::string& name) { @@ -39,7 +41,8 @@ tvm::Tensor Clip(const tvm::Tensor& X, float min_value, float max_value, const s } tvm::Tensor Elu(const tvm::Tensor& X, float alpha, const std::string& name) { - return Rename(Relu(X) - alpha * Relu(1 - Exp(X)), name); + tvm::Expr alphaExpr = tvm::make_const(X->dtype, alpha); + return Rename(Relu(X) - alphaExpr * Relu(1 - Exp(X)), name); } tvm::Tensor Exp(const tvm::Tensor& X, const std::string& name) { @@ -56,11 +59,14 @@ tvm::Tensor Floor(const tvm::Tensor& X, const std::string& name) { } tvm::Tensor HardSigmoid(const tvm::Tensor& X, float alpha, float beta, const std::string& name) { - return maximum(0, minimum(1, alpha * X + beta), name); + tvm::Expr alphaExpr = tvm::make_const(X->dtype, alpha); + tvm::Expr betaExpr = 
tvm::make_const(X->dtype, beta); + return maximum(0, minimum(1, alphaExpr * X + betaExpr), name); } tvm::Tensor LeakyRelu(const tvm::Tensor& X, float alpha, const std::string& name) { - return Rename(Relu(X) - alpha * Relu(0 - X), name); + tvm::Expr alphaExpr = tvm::make_const(X->dtype, alpha); + return Rename(Relu(X) - alphaExpr * Relu(0 - X), name); } tvm::Tensor Log(const tvm::Tensor& X, const std::string& name) { @@ -77,7 +83,9 @@ tvm::Tensor Neg(const tvm::Tensor& X, const std::string& name) { } tvm::Tensor ParametricSoftplus(const tvm::Tensor& X, float alpha, float beta, const std::string& name) { - return Rename(alpha * Softplus(beta * X), name); + tvm::Expr alphaExpr = tvm::make_const(X->dtype, alpha); + tvm::Expr betaExpr = tvm::make_const(X->dtype, beta); + return Rename(alphaExpr * Softplus(betaExpr * X), name); } tvm::Tensor Reciprocal(const tvm::Tensor& X, const std::string& name) { @@ -89,11 +97,15 @@ tvm::Tensor Relu(const tvm::Tensor& X, const std::string& name) { } tvm::Tensor ScaledTanh(const tvm::Tensor& X, float alpha, float beta, const std::string& name) { - return Rename(alpha * Tanh(beta * X), name); + tvm::Expr alphaExpr = tvm::make_const(X->dtype, alpha); + tvm::Expr betaExpr = tvm::make_const(X->dtype, beta); + return Rename(alphaExpr * Tanh(betaExpr * X), name); } tvm::Tensor Selu(const tvm::Tensor& X, float alpha, float gamma, const std::string& name) { - return Rename(gamma * (-alpha * Relu(1 - Exp(X)) + Relu(X)), name); + tvm::Expr alphaExpr = tvm::make_const(X->dtype, alpha); + tvm::Expr gammaExpr = tvm::make_const(X->dtype, gamma); + return Rename(gammaExpr * (-alphaExpr * Relu(1 - Exp(X)) + Relu(X)), name); } tvm::Tensor Sigmoid(const tvm::Tensor& X, const std::string& name) { @@ -135,7 +147,8 @@ tvm::Tensor Tanh(const tvm::Tensor& X, const std::string& name) { } tvm::Tensor ThresholdedRelu(const tvm::Tensor& X, float alpha, const std::string& name) { - return topi::where(greater(X, alpha), X, topi::full_like(X, tvm::make_zero(X->dtype)), name); + tvm::Expr alphaExpr = tvm::make_const(X->dtype, alpha); + return topi::where(greater(X, alphaExpr), X, topi::full_like(X, tvm::make_zero(X->dtype)), name); } } // namespace tvm_codegen diff --git a/onnxruntime/core/codegen/mti/mti_tvm_utils.cc b/onnxruntime/core/codegen/mti/mti_tvm_utils.cc index e905a34432a6e..3696deea22b3c 100644 --- a/onnxruntime/core/codegen/mti/mti_tvm_utils.cc +++ b/onnxruntime/core/codegen/mti/mti_tvm_utils.cc @@ -4,6 +4,7 @@ #include "core/codegen/mti/mti_tvm_utils.h" #include "core/codegen/common/settings.h" +#include "core/codegen/mti/tensor/reshape_ops.h" #include #include @@ -158,5 +159,38 @@ bool BroadcastDim(const tvm::Array& shape, size_t i, size_t output_ra return true; } +tvm::Array MakeInputsForExtern(const tvm::Array& inputs, const std::string& name) { + // note that currently TVM StorageFlatten creates strides like max(symbolic_dim, 1) + // which is not zero when checking symbolic_dim - max(symbolic_dim, 1) + // then triggers error like: Trying to bind compact buffer to strided one + // here's a workaround to reshape inputs to avoid that + tvm::Array fixed_inputs; + for (size_t idx_input = 0; idx_input < inputs.size(); ++idx_input) { + const auto& input = inputs[idx_input]; + tvm::Array fixed_shape; + if (input->shape.size() > 0) { + // stride compute does not use dim 0, so directly push to fixed_shape + fixed_shape.push_back(input->shape[0]); + bool need_fix = false; + for (size_t idx_dim = 1; idx_dim < input->shape.size(); ++idx_dim) { + const auto& dim = 
input->shape[idx_dim]; + if (tvm::as_const_int(dim) == nullptr) { + fixed_shape.push_back(tvm::max(dim, tvm::make_const(HalideIR::Int(32), 1))); + need_fix = true; + } else { + fixed_shape.push_back(dim); + } + } + if (need_fix) { + fixed_inputs.push_back(tvm_codegen::Reshape(input, fixed_shape, name + "_" + std::to_string(idx_input))); + continue; + } + } + // no fix needed + fixed_inputs.push_back(input); + } + return fixed_inputs; +} + } // namespace tvm_codegen } // namespace onnxruntime diff --git a/onnxruntime/core/codegen/mti/mti_tvm_utils.h b/onnxruntime/core/codegen/mti/mti_tvm_utils.h index 3f65658554f2c..034a4fe28b23a 100644 --- a/onnxruntime/core/codegen/mti/mti_tvm_utils.h +++ b/onnxruntime/core/codegen/mti/mti_tvm_utils.h @@ -60,5 +60,8 @@ inline int64_t HandleNegativeAxis(int64_t axis, int64_t rank) { return axis = axis < 0 ? (axis + rank) : axis; } +// Helper function to workaround tvm ExternOp issue when input has symbolic dimensions +tvm::Array MakeInputsForExtern(const tvm::Array& inputs, const std::string& name = "make_inputs_for_extern"); + } // namespace tvm_codegen } // namespace onnxruntime diff --git a/onnxruntime/core/codegen/passes/op_ir_creator/math/quantize/matmul_integer.cc b/onnxruntime/core/codegen/passes/op_ir_creator/math/quantize/matmul_integer.cc index 60841d049e734..6f66b1f1a2afb 100644 --- a/onnxruntime/core/codegen/passes/op_ir_creator/math/quantize/matmul_integer.cc +++ b/onnxruntime/core/codegen/passes/op_ir_creator/math/quantize/matmul_integer.cc @@ -16,19 +16,19 @@ Status GENERIC_OP_IR_CREATOR_CLASS(MatMulInteger)::Evaluate( const Node& node, CodeGenContext& ctx_codegen, tvm::Array& outputs) { - const auto& lhs_tensor = inputs[0]; - const auto& rhs_tensor = inputs[1]; + const auto& A = inputs[0]; + const auto& B = inputs[1]; auto& name = node.Name(); // A generic path, cast to int32 // Support skipped trailing inputs - auto lhs = (node.InputDefs().size() >= 3 && node.InputDefs()[2]->Exists()) - ? Sub(Cast(lhs_tensor, HalideIR::Int(32)), Cast(inputs[2], HalideIR::Int(32))) - : Cast(lhs_tensor, HalideIR::Int(32)); - auto rhs = (node.InputDefs().size() >= 4 && node.InputDefs()[3]->Exists()) - ? Sub(Cast(rhs_tensor, HalideIR::Int(32)), Cast(inputs[3], HalideIR::Int(32))) - : Cast(rhs_tensor, HalideIR::Int(32)); - tvm::Tensor Y = MatMul(lhs, rhs, name + "_MatMulInteger"); + auto A_Int32 = (node.InputDefs().size() >= 3 && node.InputDefs()[2]->Exists()) + ? Sub(Cast(A, HalideIR::Int(32)), Cast(inputs[2], HalideIR::Int(32))) + : Cast(A, HalideIR::Int(32)); + auto B_Int32 = (node.InputDefs().size() >= 4 && node.InputDefs()[3]->Exists()) + ? 
Sub(Cast(B, HalideIR::Int(32)), Cast(inputs[3], HalideIR::Int(32))) + : Cast(B, HalideIR::Int(32)); + tvm::Tensor Y = MatMul(A_Int32, B_Int32, name + "_MatMulInteger"); outputs.push_back(Y); return Status::OK(); } diff --git a/onnxruntime/core/codegen/passes/op_ir_creator/tensor/crop.cc b/onnxruntime/core/codegen/passes/op_ir_creator/tensor/crop.cc index 46adb7e984f2d..3b6a9a76f0723 100644 --- a/onnxruntime/core/codegen/passes/op_ir_creator/tensor/crop.cc +++ b/onnxruntime/core/codegen/passes/op_ir_creator/tensor/crop.cc @@ -29,7 +29,8 @@ Status GENERIC_OP_IR_CREATOR_CLASS(Crop)::Evaluate( ORT_ENFORCE(attrs.GetAttrs("border", border).IsOK()); // scale is optional and status is false when omit - attrs.GetAttrs("scale", scale); + bool is_ok = attrs.GetAttrs("scale", scale).IsOK(); + ORT_UNUSED_PARAMETER(is_ok); if (border.size() != 4) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, diff --git a/onnxruntime/core/codegen/passes/op_ir_creator/tensor/transpose.cc b/onnxruntime/core/codegen/passes/op_ir_creator/tensor/transpose.cc index f4d7bb1da5e97..43999ebd1f465 100644 --- a/onnxruntime/core/codegen/passes/op_ir_creator/tensor/transpose.cc +++ b/onnxruntime/core/codegen/passes/op_ir_creator/tensor/transpose.cc @@ -21,20 +21,22 @@ Status GENERIC_OP_IR_CREATOR_CLASS(Transpose)::Evaluate( size_t input_0_shape_rank = inputs[0]->shape.size(); std::vector permute; - attrs.GetAttrs("perm", permute); + bool is_ok = attrs.GetAttrs("perm", permute).IsOK(); if (permute.size() != 0 && permute.size() != input_0_shape_rank) return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Transpose: Incorrect permute size"); std::vector default_permute; const std::vector* perm; - if (permute.size() > 0) { - perm = &permute; - } else { + // either we don't have perm attribute or the perm attribute is empty + bool use_default_perm = !is_ok || permute.size() == 0; + if (use_default_perm) { default_permute.resize(input_0_shape_rank); for (size_t i = 0; i < input_0_shape_rank; ++i) { default_permute[i] = gsl::narrow(input_0_shape_rank - 1 - i); } perm = &default_permute; + } else { + perm = &permute; } tvm::Tensor Y = Transpose(inputs[0], ToTvmArrayInt(*perm), node.Name() + "_Transpose"); diff --git a/onnxruntime/core/codegen/passes/op_ir_creator/tvm_op_creator.h b/onnxruntime/core/codegen/passes/op_ir_creator/tvm_op_creator.h index fe2648462e4f5..e29c4a9f20767 100644 --- a/onnxruntime/core/codegen/passes/op_ir_creator/tvm_op_creator.h +++ b/onnxruntime/core/codegen/passes/op_ir_creator/tvm_op_creator.h @@ -29,7 +29,7 @@ class OpIRDispatcher : public codegen::DispatcherBase { OpIRDispatcher(const std::string& name) : DispatcherBase(name) {} - ~OpIRDispatcher() = default; + virtual ~OpIRDispatcher() = default; virtual OpIRCreator* Find(const Node&) = 0; diff --git a/onnxruntime/core/codegen/passes/scheduler/tvm_scheduler.h b/onnxruntime/core/codegen/passes/scheduler/tvm_scheduler.h index 413e0fb504e89..d022497c77f7e 100644 --- a/onnxruntime/core/codegen/passes/scheduler/tvm_scheduler.h +++ b/onnxruntime/core/codegen/passes/scheduler/tvm_scheduler.h @@ -58,7 +58,7 @@ class TVMScheduleDispatcher : public codegen::DispatcherBase { TVMScheduleDispatcher(const std::string& name) : DispatcherBase(name) {} - ~TVMScheduleDispatcher() = default; + virtual ~TVMScheduleDispatcher() = default; virtual Scheduler* Find(const tvm::Tensor&, const Node*, diff --git a/onnxruntime/core/codegen/passes/utils/ort_tvm_utils.cc b/onnxruntime/core/codegen/passes/utils/ort_tvm_utils.cc index f7906b71e1189..670a540404c94 100644 --- 
a/onnxruntime/core/codegen/passes/utils/ort_tvm_utils.cc +++ b/onnxruntime/core/codegen/passes/utils/ort_tvm_utils.cc @@ -100,7 +100,7 @@ tvm::Expr ShapeDimToTvmDim(const ONNX_NAMESPACE::TensorShapeProto_Dimension& dim #ifdef CODEGEN_ENABLE_PROFILER struct event_in_bracket_and_id { bool in_bracket; - int id; + size_t id; }; std::unordered_map g_codegen_profiler_event_ids; std::vector> g_codegen_profiler_events(1024); @@ -109,7 +109,7 @@ TVM_REGISTER_GLOBAL("tvm.contrib.onnxruntime.profile_event") .set_body([](tvm::TVMArgs args, tvm::TVMRetValue* ret) { DLTensor* X = args[0]; DLTensor* Y = args[1]; - int event_id = args[2]; + size_t event_id = args[2]; bool is_begin = args[3]; if (!is_begin) { DCHECK(event_id < g_codegen_profiler_event_ids.size()); @@ -120,7 +120,7 @@ TVM_REGISTER_GLOBAL("tvm.contrib.onnxruntime.profile_event") } { - CODEGEN_PROFILER_EVENT(profile_stub); + CODEGEN_PROFILER_EVENT("profile_stub"); int64_t elem_count = 1; for (int i = 0; i < X->ndim; ++i) { elem_count *= X->shape[i]; @@ -141,7 +141,7 @@ TVM_REGISTER_GLOBAL("tvm.contrib.onnxruntime.profile_event") }); tvm::Tensor ProfileBegin(tvm::Tensor X, const std::string& event_name) { - int event_id; + size_t event_id; if (g_codegen_profiler_event_ids.count(event_name) == 0) { event_id = g_codegen_profiler_event_ids.size(); ORT_ENFORCE(event_id < g_codegen_profiler_events.size()); @@ -157,7 +157,7 @@ tvm::Tensor ProfileBegin(tvm::Tensor X, const std::string& event_name) { return topi::detail::call_packed({tvm::Expr("tvm.contrib.onnxruntime.profile_event"), topi::detail::pack_buffer(ins[0]), topi::detail::pack_buffer(outs[0]), - event_id, + gsl::narrow(event_id), true}); }, event_name + "_begin", "", {})[0]; @@ -166,7 +166,7 @@ tvm::Tensor ProfileBegin(tvm::Tensor X, const std::string& event_name) { tvm::Tensor ProfileEnd(tvm::Tensor X, const std::string& event_name) { ORT_ENFORCE(g_codegen_profiler_event_ids.at(event_name).in_bracket); g_codegen_profiler_event_ids.at(event_name).in_bracket = false; - int event_id = g_codegen_profiler_event_ids.at(event_name).id; + size_t event_id = g_codegen_profiler_event_ids.at(event_name).id; ORT_ENFORCE(event_id < g_codegen_profiler_events.size()); ORT_ENFORCE(g_codegen_profiler_events[event_id].first == event_name); return topi::detail::make_extern( @@ -175,7 +175,7 @@ tvm::Tensor ProfileEnd(tvm::Tensor X, const std::string& event_name) { return topi::detail::call_packed({tvm::Expr("tvm.contrib.onnxruntime.profile_event"), topi::detail::pack_buffer(ins[0]), topi::detail::pack_buffer(outs[0]), - event_id, + gsl::narrow(event_id), false}); }, event_name + "_end", "", {})[0]; diff --git a/onnxruntime/core/codegen/passes/weight_layout/weight_layout.h b/onnxruntime/core/codegen/passes/weight_layout/weight_layout.h index bcd9b229b5a3d..af61641a74937 100644 --- a/onnxruntime/core/codegen/passes/weight_layout/weight_layout.h +++ b/onnxruntime/core/codegen/passes/weight_layout/weight_layout.h @@ -30,7 +30,7 @@ class WeightLayout { int input_dim, float pad_zero); - ~WeightLayout() = default; + virtual ~WeightLayout() = default; // Return a CoordTransFunc from actual (transformed) coordinate to normial (original) coordinate virtual CoordTransFunc ToNominal(const tvm::Tensor& X) const = 0; diff --git a/onnxruntime/core/common/logging/capture.cc b/onnxruntime/core/common/logging/capture.cc index 016ddb9fc06be..6223d2ca70ec2 100644 --- a/onnxruntime/core/common/logging/capture.cc +++ b/onnxruntime/core/common/logging/capture.cc @@ -27,16 +27,26 @@ void Capture::ProcessPrintf(msvc_printf_check const 
char* format, va_list args) char message_buffer[kMaxMessageSize]; const auto message = gsl::make_span(message_buffer); + bool error = false; + bool truncated = false; + #if (defined(WIN32) || defined(_WIN32) || defined(__WIN32__) && !defined(__GNUC__)) + errno = 0; const int nbrcharacters = vsnprintf_s(message.data(), message.size(), _TRUNCATE, format, args); + if (nbrcharacters < 0) { + error = errno != 0; + truncated = !error; + } #else const int nbrcharacters = vsnprintf(message.data(), message.size(), format, args); + error = nbrcharacters < 0; + truncated = nbrcharacters > message.size(); #endif - if (nbrcharacters <= 0) { + if (error) { stream_ << "\n\tERROR LOG MSG NOTIFICATION: Failure to successfully parse the message"; stream_ << '"' << format << '"' << std::endl; - } else if (nbrcharacters > message.size()) { + } else if (truncated) { stream_ << message.data() << kTruncatedWarningText; } else { stream_ << message.data(); diff --git a/onnxruntime/core/common/profiler.cc b/onnxruntime/core/common/profiler.cc index d8eb1b2354027..1fa0577a676b9 100644 --- a/onnxruntime/core/common/profiler.cc +++ b/onnxruntime/core/common/profiler.cc @@ -7,6 +7,16 @@ namespace onnxruntime { namespace profiling { using namespace std::chrono; +#ifdef ENABLE_STATIC_PROFILER_INSTANCE +Profiler* Profiler::instance_ = nullptr; + +profiling::Profiler::~Profiler() { + instance_ = nullptr; +} +#else +profiling::Profiler::~Profiler() {} +#endif + ::onnxruntime::TimePoint profiling::Profiler::StartTime() const { return std::chrono::high_resolution_clock::now(); } @@ -14,6 +24,14 @@ ::onnxruntime::TimePoint profiling::Profiler::StartTime() const { void Profiler::Initialize(const logging::Logger* session_logger) { ORT_ENFORCE(session_logger != nullptr); session_logger_ = session_logger; +#ifdef ENABLE_STATIC_PROFILER_INSTANCE + // In current design, profiler instance goes with inference session. Since it's possible to have + // multiple inference sessions, profiler by definition is not singleton. However, in performance + // debugging, it would be helpful to access profiler in code that have no access to inference session, + // which is why we have this pseudo-singleton implementation here for debugging in single inference session. + ORT_ENFORCE(instance_ == nullptr, "Static profiler instance only works with single session"); + instance_ = this; +#endif } void Profiler::StartProfiling(const logging::Logger* custom_logger) { diff --git a/onnxruntime/core/common/profiler.h b/onnxruntime/core/common/profiler.h index 48ecf5747467a..815695a4fa4ed 100644 --- a/onnxruntime/core/common/profiler.h +++ b/onnxruntime/core/common/profiler.h @@ -13,6 +13,10 @@ namespace onnxruntime { namespace profiling { +// uncomment the macro below, or use -DENABLE_STATIC_PROFILER_INSTANCE for debugging +// note that static profiler instance only works with single session +//#define ENABLE_STATIC_PROFILER_INSTANCE + /** * Main class for profiling. It continues to accumulate events and produce * a corresponding "complete event (X)" in "chrome tracing" format. 
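For reference, the "chrome tracing" format mentioned above encodes each complete event as one JSON object whose ph field is "X", with microsecond ts/dur fields. A representative record (field values are illustrative) wrapped in a small C++ snippet:

#include <iostream>

int main() {
    // Illustrative only: the general shape of one "complete event (X)" record
    // as understood by the chrome://tracing viewer.
    const char* sample_event =
        R"({"cat":"Node","pid":1234,"tid":5678,"ts":100,"dur":42,)"
        R"("ph":"X","name":"example_kernel_time","args":{}})";
    std::cout << sample_event << std::endl;
    return 0;
}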
@@ -23,6 +27,8 @@ class Profiler { /// Even this function is marked as noexcept, the code inside it may throw exceptions Profiler() noexcept {}; //NOLINT + ~Profiler(); + /* Initializes Profiler with the session logger to log framework specific messages */ @@ -67,6 +73,15 @@ class Profiler { */ std::string EndProfiling(); + static Profiler& Instance() { +#ifdef ENABLE_STATIC_PROFILER_INSTANCE + ORT_ENFORCE(instance_ != nullptr); + return *instance_; +#else + ORT_THROW("Static profiler instance is not enabled, please compile with -DENABLE_STATIC_PROFILER_INSTANCE"); +#endif + } + private: ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(Profiler); @@ -82,6 +97,10 @@ class Profiler { bool max_events_reached{false}; static constexpr size_t max_num_events_ = 1000000; bool profile_with_logger_{false}; + +#ifdef ENABLE_STATIC_PROFILER_INSTANCE + static Profiler* instance_; +#endif }; } // namespace profiling diff --git a/onnxruntime/core/common/task_thread_pool.h b/onnxruntime/core/common/task_thread_pool.h deleted file mode 100644 index 1cc0d64ecfd6b..0000000000000 --- a/onnxruntime/core/common/task_thread_pool.h +++ /dev/null @@ -1,213 +0,0 @@ -/** - * Copyright (c) 2016-present, Facebook, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -/* -Changed to use std::packaged_task instead of std::function so exceptions can be propagated. - -This also allows the task threadpool to be shared across multiple operators as the caller -can keep a container of the packaged_task futures to check when they have completed. Calling -WaitWorkComplete in that use case is invalid as there may be other concurrent usage of the -threadpool. - -Example of that usage: - - std::vector> task_results{}; - - for (...) { - std::packaged_task task{std::bind(lambda, i)}; - task_results.push_back(task.get_future()); - task_thread_pool.RunTask(std::move(task)); - } - - try { - // wait for all and propagate any exceptions - for (auto& future : task_results) - future.get(); - } catch (const std::exception& ex) { - ... - throw; - } - -*/ - -#pragma once - -#include -#include -#include -#include -#include -#include -#include - -#include "core/common/common.h" -#include "core/common/logging/logging.h" -#include "core/platform/ort_mutex.h" - -namespace onnxruntime { - -class TaskThreadPool { - private: - struct task_element_t { - bool run_with_id; - std::packaged_task no_id; - std::packaged_task with_id; - - task_element_t(task_element_t&& other) noexcept { - run_with_id = other.run_with_id; - no_id = std::move(other.no_id); - with_id = std::move(other.with_id); - } - - explicit task_element_t(std::packaged_task&& f) - : run_with_id(false), no_id(std::move(f)) {} - - explicit task_element_t(std::packaged_task&& f) - : run_with_id(true), with_id(std::move(f)) {} - }; - - std::queue tasks_; - std::vector threads_; - OrtMutex mutex_; - OrtCondVar condition_; - OrtCondVar completed_; - bool running_; - bool complete_; - std::size_t available_; - std::size_t total_; - - public: - /// @brief Constructor. 
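The removed TaskThreadPool's header comment above relies on std::packaged_task carrying exceptions back through the stored futures; a self-contained sketch of that mechanism, with no thread pool involved:

#include <future>
#include <iostream>
#include <stdexcept>

int main() {
    std::packaged_task<void()> task([] { throw std::runtime_error("boom"); });
    std::future<void> result = task.get_future();

    task();              // the exception is stored in the shared state, not thrown here

    try {
        result.get();    // rethrows the stored exception to whoever waits on the future
    } catch (const std::exception& ex) {
        std::cout << "propagated: " << ex.what() << std::endl;
    }
    return 0;
}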
- explicit TaskThreadPool(std::size_t pool_size) - : threads_(pool_size), running_(true), complete_(true), available_(pool_size), total_(pool_size) { - for (std::size_t i = 0; i < pool_size; ++i) { - threads_[i] = std::thread(std::bind(&TaskThreadPool::MainLoop, this, i)); - } - } - - /// @brief Destructor. - ~TaskThreadPool() { - // Set running flag to false then notify all threads. - { - std::unique_lock lock(mutex_); - running_ = false; - condition_.notify_all(); - } - - try { - for (auto& t : threads_) { - t.join(); - } - } - // Suppress all exceptions. - catch (const std::exception& ex) { - LOGS_DEFAULT(ERROR) << "Exception joining threads in TaskThreadPool: " << ex.what(); - } - } - - int NumThreads() const { - return (int)threads_.size(); - } - - // This thread pool does not support ids - int CurrentThreadId() const { - return -1; - } - - void RunTask(std::packaged_task&& task) { - std::unique_lock lock(mutex_); - - // Set task and signal condition variable so that a worker thread will - // wake up and use the task. - tasks_.push(task_element_t(std::move(task))); - complete_ = false; - condition_.notify_one(); - } - - void RunTaskWithID(std::packaged_task&& task) { - std::unique_lock lock(mutex_); - - // Set task and signal condition variable so that a worker thread will - // wake up and use the task. - tasks_.push(task_element_t(std::move(task))); - complete_ = false; - condition_.notify_one(); - } - - /// @brief Wait for queue to be empty - void WaitWorkComplete() { - std::unique_lock lock(mutex_); - while (!complete_) - completed_.wait(lock); - } - - private: - ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(TaskThreadPool); - - /// @brief Entry point for pool threads. - void MainLoop(std::size_t index) { - while (running_) { - // Wait on condition variable while the task is empty and - // the pool is still running. - std::unique_lock lock(mutex_); - while (tasks_.empty() && running_) { - condition_.wait(lock); - } - - // If pool is no longer running, break out of loop. - if (!running_) break; - - // Copy task locally and remove from the queue. This is - // done within its own scope so that the task object is - // destructed immediately after running the task. This is - // useful in the event that the function contains - // shared_ptr arguments bound via bind. - { - auto task = std::move(tasks_.front()); - tasks_.pop(); - // Decrement count, indicating thread is no longer available. - --available_; - - lock.unlock(); - - // Run the task. - try { - if (task.run_with_id) { - task.with_id(index); - } else { - task.no_id(); - } - } catch (const std::exception& /*ex*/) { - // LOGS_DEFAULT(ERROR) << "Exception running TaskThreadPool task: " << ex.what(); - throw; - } - - // Update status of empty, maybe - // Need to recover the lock first - lock.lock(); - - // Increment count, indicating thread is available. 
- ++available_; - if (tasks_.empty() && available_ == total_) { - complete_ = true; - completed_.notify_one(); - } - } - } // while running_ - } -}; - -} // namespace onnxruntime diff --git a/onnxruntime/core/common/threadpool.cc b/onnxruntime/core/common/threadpool.cc index 07305a41d0645..6cdcb3add7cf0 100644 --- a/onnxruntime/core/common/threadpool.cc +++ b/onnxruntime/core/common/threadpool.cc @@ -6,174 +6,31 @@ #include -#ifdef USE_EIGEN_THREADPOOL -#if defined(_MSC_VER) -#pragma warning(disable : 4267) -#endif - #if defined(__GNUC__) #pragma GCC diagnostic push #pragma GCC diagnostic ignored "-Wunused-parameter" +#else +#pragma warning(push) +#pragma warning(disable : 4267) #endif -#include +#include #if defined(__GNUC__) #pragma GCC diagnostic pop -#endif #else -#include "task_thread_pool.h" +#pragma warning(pop) #endif +using Eigen::Barrier; + namespace onnxruntime { namespace concurrency { - -// TODO: This is temporarily taken from Eigen until we upgrade its version. -// Barrier is an object that allows one or more threads to wait until -// Notify has been called a specified number of times. -class Barrier { - public: - Barrier(unsigned int count) : state_(count << 1), notified_(false) { - assert(((count << 1) >> 1) == count); - } - ~Barrier() { - assert((state_ >> 1) == 0); - } - - void Notify() { - unsigned int v = state_.fetch_sub(2, std::memory_order_acq_rel) - 2; - if (v != 1) { - assert(((v + 2) & ~1) != 0); - return; // either count has not dropped to 0, or waiter is not waiting - } - std::unique_lock l(mu_); - assert(!notified_); - notified_ = true; - cv_.notify_all(); - } - - void Wait() { - unsigned int v = state_.fetch_or(1, std::memory_order_acq_rel); - if ((v >> 1) == 0) return; - std::unique_lock l(mu_); - while (!notified_) { - cv_.wait(l); - } - } - - private: - std::mutex mu_; - std::condition_variable cv_; - std::atomic state_; // low bit is waiter flag - bool notified_; -}; - -#ifdef USE_EIGEN_THREADPOOL -class ThreadPool::Impl : public Eigen::ThreadPool { - public: - Impl(const std::string& name, int num_threads) - : Eigen::ThreadPool(num_threads) { - ORT_UNUSED_PARAMETER(name); - } - - void ParallelFor(int32_t total, std::function fn) { - // TODO: Eigen supports a more efficient ThreadPoolDevice mechanism - // We will simply rely on the work queue and stealing in the short term. - Barrier barrier(static_cast(total - 1)); - std::function handle_iteration = [&barrier, &fn](int iteration) { - fn(iteration); - barrier.Notify(); - }; - - for (int32_t id = 1; id < total; ++id) { - Schedule([=, &handle_iteration]() { handle_iteration(id); }); - } - - fn(0); - barrier.Wait(); - } - - void ParallelForRange(int64_t first, int64_t last, std::function fn) { - // TODO: Eigen supports a more efficient ThreadPoolDevice mechanism - // We will simply rely on the work queue and stealing in the short term. 
- Barrier barrier(static_cast(last - first)); - std::function handle_range = [&barrier, &fn](int64_t first, int64_t last) { - fn(first, last); - barrier.Notify(); - }; - - for (int64_t id = first + 1; id <= last; ++id) { - Schedule([=, &handle_range]() { handle_range(id, id + 1); }); - } - - fn(first, first + 1); - barrier.Wait(); - } -}; -#else -class ThreadPool::Impl : public TaskThreadPool { - public: - Impl(const std::string& name, int num_threads) - : TaskThreadPool(num_threads) { - ORT_UNUSED_PARAMETER(name); - } - - void Schedule(std::function fn) { - std::packaged_task task(fn); - RunTask(std::move(task)); - } - - void ParallelFor(int32_t total, std::function fn) { -#ifdef USE_OPENMP -#pragma omp parallel for - for (int32_t id = 0; id < total; ++id) { - fn(id); - } -#else - Barrier barrier(static_cast(total - 1)); - std::function handle_iteration = [&barrier, &fn](int iteration) { - fn(iteration); - barrier.Notify(); - }; - for (int32_t id = 1; id < total; ++id) { - std::packaged_task task(std::bind(handle_iteration, id)); - RunTask(std::move(task)); - } - fn(0); - barrier.Wait(); -#endif - } - - void ParallelForRange(int64_t first, int64_t last, std::function fn) { -#ifdef USE_OPENMP -#pragma omp parallel for - for (int64_t id = first; id < last; ++id) { - fn(id, id + 1); - } -#else - Barrier barrier(static_cast(last - first)); - std::function handle_iteration = [&barrier, &fn](int64_t first, int64_t last) { - fn(first, last); - barrier.Notify(); - }; - for (int64_t id = first + 1; id < last; ++id) { - std::packaged_task task(std::bind(handle_iteration, id, id + 1)); - RunTask(std::move(task)); - } - fn(first, first + 1); - barrier.Wait(); -#endif - } -}; -#endif - // // ThreadPool // -ThreadPool::ThreadPool(const std::string& name, int num_threads) - : impl_(std::make_unique(name, num_threads)) { -} +ThreadPool::ThreadPool(const std::string&, int num_threads) : impl_(num_threads) {} -void ThreadPool::Schedule(std::function fn) { impl_->Schedule(fn); } +void ThreadPool::Schedule(std::function fn) { impl_.Schedule(fn); } void ThreadPool::ParallelFor(int32_t total, std::function fn) { if (total <= 0) return; @@ -183,7 +40,20 @@ void ThreadPool::ParallelFor(int32_t total, std::function fn) { return; } - impl_->ParallelFor(total, fn); + // TODO: Eigen supports a more efficient ThreadPoolDevice mechanism + // We will simply rely on the work queue and stealing in the short term. + Barrier barrier(static_cast(total - 1)); + std::function handle_iteration = [&barrier, &fn](int iteration) { + fn(iteration); + barrier.Notify(); + }; + + for (int32_t id = 1; id < total; ++id) { + Schedule([=, &handle_iteration]() { handle_iteration(id); }); + } + + fn(0); + barrier.Wait(); } void ThreadPool::ParallelForRange(int64_t first, int64_t last, std::function fn) { @@ -193,18 +63,28 @@ void ThreadPool::ParallelForRange(int64_t first, int64_t last, std::functionParallelForRange(first, last, fn); + // TODO: Eigen supports a more efficient ThreadPoolDevice mechanism + // We will simply rely on the work queue and stealing in the short term. 
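[Editorial aside, not part of the patch] The new ParallelFor above and the ParallelForRange that follows both use the same fork/join shape: schedule iterations 1..N-1 onto the pool, run iteration 0 on the calling thread, then block on a barrier that each scheduled iteration notifies. Below is a minimal, self-contained sketch of that pattern using plain std primitives; SimpleBarrier and ParallelForSketch are illustrative names standing in for Eigen::Barrier and ThreadPool::Schedule, and std::thread stands in for the work queue.

```cpp
#include <condition_variable>
#include <functional>
#include <iostream>
#include <mutex>
#include <thread>
#include <vector>

// Minimal barrier: Wait() returns once Notify() has been called `count` times.
class SimpleBarrier {
 public:
  explicit SimpleBarrier(unsigned count) : remaining_(count) {}
  void Notify() {
    std::lock_guard<std::mutex> l(mu_);
    if (--remaining_ == 0) cv_.notify_all();
  }
  void Wait() {
    std::unique_lock<std::mutex> l(mu_);
    cv_.wait(l, [this] { return remaining_ == 0; });
  }

 private:
  std::mutex mu_;
  std::condition_variable cv_;
  unsigned remaining_;
};

// ParallelFor in the style of the patch: hand off iterations 1..total-1,
// run iteration 0 inline, then wait for the handed-off work to finish.
void ParallelForSketch(int total, const std::function<void(int)>& fn) {
  if (total <= 0) return;
  if (total == 1) { fn(0); return; }

  SimpleBarrier barrier(static_cast<unsigned>(total - 1));
  std::vector<std::thread> workers;  // stands in for ThreadPool::Schedule
  for (int id = 1; id < total; ++id) {
    workers.emplace_back([&, id] { fn(id); barrier.Notify(); });
  }

  fn(0);           // calling thread does one unit of work itself
  barrier.Wait();  // block until every scheduled iteration has notified
  for (auto& t : workers) t.join();
}

int main() {
  ParallelForSketch(4, [](int i) { std::cout << "iteration " << i << "\n"; });
}
```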
+ Barrier barrier(static_cast(last - first)); + std::function handle_range = [&barrier, &fn](int64_t first, int64_t last) { + fn(first, last); + barrier.Notify(); + }; + + for (int64_t id = first + 1; id <= last; ++id) { + Schedule([=, &handle_range]() { handle_range(id, id + 1); }); + } + + fn(first, first + 1); + barrier.Wait(); } // void ThreadPool::SetStealPartitions(const std::vector>& partitions) { // impl_->SetStealPartitions(partitions); // } -int ThreadPool::NumThreads() const { return impl_->NumThreads(); } - -int ThreadPool::CurrentThreadId() const { return impl_->CurrentThreadId(); } - -ThreadPool::~ThreadPool() {} +int ThreadPool::NumThreads() const { return impl_.NumThreads(); } +int ThreadPool::CurrentThreadId() const { return impl_.CurrentThreadId(); } } // namespace concurrency } // namespace onnxruntime diff --git a/onnxruntime/core/framework/allocation_planner.cc b/onnxruntime/core/framework/allocation_planner.cc index 552d702b80e7c..5046b6e7b5fd6 100644 --- a/onnxruntime/core/framework/allocation_planner.cc +++ b/onnxruntime/core/framework/allocation_planner.cc @@ -338,12 +338,19 @@ class PlannerImpl { // Initialize execution plan: plan_.execution_plan.reserve(num_graph_nodes); + // Initialize node_has_fence. + plan_.node_has_fence.resize(graph_viewer_.MaxNodeIndex()); + // Initialize allocation plan: plan_.allocation_plan.resize(num_ml_values); } Status ComputeUseCounts() { // Note: for every ml-value, its definition must appear before all its uses in a topological sort of a valid model + std::unordered_set graph_inputs; + for (auto& graph_input : graph_viewer_.GetInputsIncludingInitializers()) { + graph_inputs.insert(graph_input->Name()); + } for (auto graph_input : graph_viewer_.GetInputs()) { OrtValueIndex index = Index(graph_input->Name()); @@ -368,15 +375,7 @@ class PlannerImpl { for (SequentialExecutionPlan::NodeExecutionPlan& step : plan_.execution_plan) { auto pnode = graph_viewer_.GetNode(step.node_index); if (pnode == nullptr) return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Can not find the node ", step.node_index); - for (auto node_input : pnode->InputDefs()) { - if (node_input->Exists()) - UseCount(node_input->Name())++; - } - for (auto node_input : pnode->ImplicitInputDefs()) { - if (node_input->Exists()) - UseCount(node_input->Name())++; - } // Identify where each output of this node should be allocated. // This is determined by the opkernel bound to the node. const KernelCreateInfo* kernel_create_info = nullptr; @@ -391,31 +390,45 @@ class PlannerImpl { if (!pnode->Name().empty()) errormsg << " (node " << pnode->Name() << ")"; return Status(ONNXRUNTIME, FAIL, errormsg.str()); } - auto exec_provider = execution_providers_.Get(*pnode); if (exec_provider == nullptr) { return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Can not find the execution provider ", pnode->GetExecutionProviderType()); } - auto& default_allocator_info = exec_provider->GetAllocator(0, OrtMemTypeDefault)->Info(); + // increment UseCount and add location information if applicable for the provided input def + auto process_input = [&graph_inputs, &exec_provider, &p_kernelDef, this](const NodeArg& input, size_t arg_idx) { + const auto& name = input.Name(); + UseCount(name)++; + + // If it's a graph input or outer scope node arg, set its plan. + // NOTE: Copy nodes should have already been added if a graph input is fed as input + // to nodes assigned to different providers. 
+ if (graph_inputs.find(name) != graph_inputs.cend() || + std::find_if(outer_scope_node_args_.cbegin(), outer_scope_node_args_.cend(), + [&name](const NodeArg* value) { + return value && value->Name() == name; + }) != outer_scope_node_args_.cend()) { + OrtValueIndex index = Index(name); + plan_.SetLocation(static_cast(index), + exec_provider->GetAllocator(0, p_kernelDef->InputMemoryType(arg_idx))->Info()); + } + + return Status::OK(); + }; + + ORT_RETURN_IF_ERROR(Node::ForEachWithIndex(pnode->InputDefs(), process_input)); + ORT_RETURN_IF_ERROR(Node::ForEachWithIndex(pnode->ImplicitInputDefs(), process_input)); + auto outputs = pnode->OutputDefs(); auto num_outputs = outputs.size(); - for (size_t i = 0; i < num_outputs; ++i) { auto* node_output = outputs[i]; if (!node_output->Exists()) continue; OrtValueIndex index = Index(node_output->Name()); ProcessDef(index, node_output); ++UseCount(index); - if (strcmp(default_allocator_info.name, CPU) != 0) { - // By default, outputs of this node are allocated on the default device allocator, - // except for outputs marked for allocation in MemoryType: - auto memory_type = p_kernelDef->OutputMemoryType(i); - plan_.SetLocation(static_cast(index), memory_type == OrtMemTypeDefault - ? default_allocator_info - : exec_provider->GetAllocator(0, memory_type)->Info()); - } + plan_.SetLocation(static_cast(index), exec_provider->GetAllocator(0, p_kernelDef->OutputMemoryType(i))->Info()); } // if sync is needed, mark allocation plan as create_fence_if_async=true // note that the input arg may come from an execution provider (i.e. CPU) that does not support async, @@ -585,6 +598,51 @@ class PlannerImpl { return Status::OK(); } + // Whether a given NodeArg has fence or not. + // If the buffer is reused, need to check whether original OrtValue has fence or not. + bool HasFence(const onnxruntime::NodeArg* arg) { + bool has_fence = false; + if (arg && arg->Exists()) { + OrtValueIndex index = Index(arg->Name()); + AllocPlanPerValue& value_plan = AllocPlan(index); + + has_fence = value_plan.create_fence_if_async; + if (value_plan.alloc_kind == AllocKind::kReuse) + { + // Buffer reused, check original buffer to see if fence is shared. + has_fence = has_fence || AllocPlan(value_plan.reused_buffer).create_fence_if_async; + } + } + + return has_fence; + } + + // Compute fence check. Set has_fence flag if either one of inputs, implicit inputs or outputs of a given node has fence. + Status ComputeFenceCheck() { + + for (SequentialExecutionPlan::NodeExecutionPlan& step : plan_.execution_plan) { + auto pnode = graph_viewer_.GetNode(step.node_index); + if (pnode == nullptr) return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Can not find the node ", step.node_index); + + bool has_fence = false; + for (auto node_input : pnode->InputDefs()) { + has_fence = has_fence || HasFence(node_input); + } + + for (auto node_input : pnode->ImplicitInputDefs()) { + has_fence = has_fence || HasFence(node_input); + } + + for (auto node_output : pnode->OutputDefs()) { + has_fence = has_fence || HasFence(node_output); + } + + plan_.node_has_fence[step.node_index] = has_fence; + } + + return Status::OK(); + } + // Convert information in a freelist (about which ml-value becomes free when) into // a deallocation plan in the format required in an ExecutionPlan void GenerateDeallocationPlan() { @@ -642,6 +700,9 @@ Status PlannerImpl::CreatePlan() { // determine sharing/reuse among ml-values ORT_RETURN_IF_ERROR(ComputeReusePlan()); + // Determine nodes that need fence check. 
This needs to be done after ComputeUseCounts and ComputeReusePlan. + ORT_RETURN_IF_ERROR(ComputeFenceCheck()); + // convert information in the freelist_ into a deallocation plan in required format GenerateDeallocationPlan(); diff --git a/onnxruntime/core/framework/allocator.cc b/onnxruntime/core/framework/allocator.cc index b8847a00801c3..800a2b898526c 100644 --- a/onnxruntime/core/framework/allocator.cc +++ b/onnxruntime/core/framework/allocator.cc @@ -3,86 +3,67 @@ #include "core/framework/allocator.h" #include "core/framework/allocatormgr.h" -#include "core/mlas/inc/mlas.h" +#include "core/framework/utils.h" #include #include namespace onnxruntime { void* CPUAllocator::Alloc(size_t size) { - if (size <= 0) - return nullptr; - void* p; - size_t alignment = MlasGetPreferredBufferAlignment(); -#if _MSC_VER - p = _aligned_malloc(size, alignment); - if (p == nullptr) throw std::bad_alloc(); -#elif defined(_LIBCPP_SGX_CONFIG) - p = memalign(alignment, size); - if (p == nullptr) throw std::bad_alloc(); -#else - int ret = posix_memalign(&p, alignment, size); - if (ret != 0) throw std::bad_alloc(); -#endif - return p; + return utils::DefaultAlloc(size); } void CPUAllocator::Free(void* p) { -#if _MSC_VER - _aligned_free(p); -#else - free(p); -#endif + utils::DefaultFree(p); } -const OrtAllocatorInfo& CPUAllocator::Info() const { - return *allocator_info_; -} +const OrtAllocatorInfo& CPUAllocator::Info() const { return *allocator_info_; } } // namespace onnxruntime -std::ostream& operator<<(std::ostream& out, const OrtAllocatorInfo& info) { - return (out << info.ToString()); -} +std::ostream& operator<<(std::ostream& out, const OrtAllocatorInfo& info) { return (out << info.ToString()); } ORT_API_STATUS_IMPL(OrtCreateAllocatorInfo, _In_ const char* name1, OrtAllocatorType type, int id1, OrtMemType mem_type1, _Out_ OrtAllocatorInfo** out) { if (strcmp(name1, onnxruntime::CPU) == 0) { *out = new OrtAllocatorInfo(name1, type, OrtDevice(), id1, mem_type1); } else if (strcmp(name1, onnxruntime::CUDA) == 0) { - *out = new OrtAllocatorInfo(name1, type, OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, static_cast(id1)), id1, mem_type1); + *out = new OrtAllocatorInfo( + name1, type, OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, static_cast(id1)), id1, + mem_type1); } else if (strcmp(name1, onnxruntime::CUDA_PINNED) == 0) { - *out = new OrtAllocatorInfo(name1, type, OrtDevice(OrtDevice::CPU, OrtDevice::MemType::CUDA_PINNED, static_cast(id1)), id1, mem_type1); + *out = new OrtAllocatorInfo( + name1, type, OrtDevice(OrtDevice::CPU, OrtDevice::MemType::CUDA_PINNED, static_cast(id1)), + id1, mem_type1); } else { return OrtCreateStatus(ORT_INVALID_ARGUMENT, "Specified device is not supported."); } return nullptr; } -ORT_API(void, OrtReleaseAllocatorInfo, _Frees_ptr_opt_ OrtAllocatorInfo* p) { - delete p; -} +ORT_API(void, OrtReleaseAllocatorInfo, _Frees_ptr_opt_ OrtAllocatorInfo* p) { delete p; } -ORT_API_STATUS_IMPL(OrtAllocatorInfoGetName, _In_ OrtAllocatorInfo* ptr, _Out_ const char** out) { +ORT_API_STATUS_IMPL(OrtAllocatorInfoGetName, _In_ const OrtAllocatorInfo* ptr, _Out_ const char** out) { *out = ptr->name; return nullptr; } -ORT_API_STATUS_IMPL(OrtAllocatorInfoGetId, _In_ OrtAllocatorInfo* ptr, _Out_ int* out) { +ORT_API_STATUS_IMPL(OrtAllocatorInfoGetId, _In_ const OrtAllocatorInfo* ptr, _Out_ int* out) { *out = ptr->id; return nullptr; } -ORT_API_STATUS_IMPL(OrtAllocatorInfoGetMemType, _In_ OrtAllocatorInfo* ptr, _Out_ OrtMemType* out) { +ORT_API_STATUS_IMPL(OrtAllocatorInfoGetMemType, 
_In_ const OrtAllocatorInfo* ptr, _Out_ OrtMemType* out) { *out = ptr->mem_type; return nullptr; } -ORT_API_STATUS_IMPL(OrtAllocatorInfoGetType, _In_ OrtAllocatorInfo* ptr, _Out_ OrtAllocatorType* out) { +ORT_API_STATUS_IMPL(OrtAllocatorInfoGetType, _In_ const OrtAllocatorInfo* ptr, _Out_ OrtAllocatorType* out) { *out = ptr->type; return nullptr; } -ORT_API_STATUS_IMPL(OrtCompareAllocatorInfo, _In_ const OrtAllocatorInfo* info1, _In_ const OrtAllocatorInfo* info2, _Out_ int* out) { +ORT_API_STATUS_IMPL(OrtCompareAllocatorInfo, _In_ const OrtAllocatorInfo* info1, _In_ const OrtAllocatorInfo* info2, + _Out_ int* out) { *out = (*info1 == *info2) ? 0 : -1; return nullptr; } diff --git a/onnxruntime/core/framework/bfc_arena.h b/onnxruntime/core/framework/bfc_arena.h index 664f6fa72a04b..bdc6496c63205 100644 --- a/onnxruntime/core/framework/bfc_arena.h +++ b/onnxruntime/core/framework/bfc_arena.h @@ -244,7 +244,7 @@ class BFCArena : public IArenaAllocator { ~AllocationRegion() { delete[] handles_; } - AllocationRegion(AllocationRegion&& other) { Swap(other); } + AllocationRegion(AllocationRegion&& other) noexcept { Swap(other); } AllocationRegion& operator=(AllocationRegion&& other) { Swap(other); diff --git a/onnxruntime/core/framework/callback.cc b/onnxruntime/core/framework/callback.cc index 414b7ad0d2dc8..deb4d1e277d47 100644 --- a/onnxruntime/core/framework/callback.cc +++ b/onnxruntime/core/framework/callback.cc @@ -1,12 +1,14 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#include "core/common/callback.h" +#include "core/framework/callback.h" -ORT_API(void, OrtRunCallback, _Frees_ptr_opt_ OrtCallback* f){ - if(f == nullptr) return; - if(f->f != nullptr) { +namespace onnxruntime { +void OrtRunCallback(OrtCallback* f) noexcept { + if (f == nullptr) return; + if (f->f != nullptr) { f->f(f->param); delete f; } } +} // namespace onnxruntime diff --git a/onnxruntime/core/framework/callback.h b/onnxruntime/core/framework/callback.h new file mode 100644 index 0000000000000..63cb3b6fcf586 --- /dev/null +++ b/onnxruntime/core/framework/callback.h @@ -0,0 +1,15 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
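[Editorial aside, not part of the patch] The callback.cc hunk above moves OrtRunCallback into the onnxruntime namespace, and the header whose addition starts here declares the OrtCallback it consumes: a plain function pointer plus an opaque parameter, where running the callback also frees the OrtCallback object itself. A minimal usage sketch follows; the struct and OrtRunCallback mirror the patched files so the example stands alone, while the malloc/free payload is purely illustrative.

```cpp
#include <cstdlib>

namespace onnxruntime {
// Re-declared here so the sketch compiles on its own; in the tree these live in
// core/framework/callback.h and callback.cc.
struct OrtCallback {
  void (*f)(void* param) noexcept;
  void* param;
};

void OrtRunCallback(OrtCallback* f) noexcept {
  if (f == nullptr) return;
  if (f->f != nullptr) {
    f->f(f->param);
    delete f;  // the callee owns and frees the OrtCallback
  }
}
}  // namespace onnxruntime

int main() {
  void* buffer = std::malloc(64);
  // Package "free this buffer" as a callback; whoever runs it takes ownership.
  auto* deleter = new onnxruntime::OrtCallback{
      [](void* p) noexcept { std::free(p); }, buffer};
  onnxruntime::OrtRunCallback(deleter);  // frees buffer, then deletes the callback
}
```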
+#pragma once + +namespace onnxruntime { +struct OrtCallback { + void (*f)(void* param) noexcept; + void* param; +}; + +/** + * f will be freed in this call + */ +void OrtRunCallback(OrtCallback* f) noexcept; +} // namespace onnxruntime \ No newline at end of file diff --git a/onnxruntime/core/framework/data_types.cc b/onnxruntime/core/framework/data_types.cc index b41a59518cfaf..a372e52058036 100644 --- a/onnxruntime/core/framework/data_types.cc +++ b/onnxruntime/core/framework/data_types.cc @@ -6,6 +6,10 @@ #include "core/framework/sparse_tensor.h" #include "core/graph/onnx_protobuf.h" +#ifdef MICROSOFT_AUTOML +#include "automl_ops/automl_types.h" +#endif + #ifdef __GNUC__ #pragma GCC diagnostic push #pragma GCC diagnostic ignored "-Wignored-qualifiers" @@ -285,6 +289,9 @@ class DataTypeRegistry { DataTypeRegistry() { RegisterAllProtos([this](MLDataType mltype) { RegisterDataType(mltype); }); +#ifdef MICROSOFT_AUTOML + automl::RegisterAutoMLTypes([this](MLDataType mltype) { RegisterDataType(mltype); }); +#endif } ~DataTypeRegistry() = default; @@ -887,6 +894,40 @@ ORT_REGISTER_NON_ONNX_TYPE(uint64_t); ORT_REGISTER_NON_ONNX_TYPE(MLFloat16); ORT_REGISTER_NON_ONNX_TYPE(BFloat16); +const std::vector& DataTypeImpl::AllFixedSizeTensorExceptHalfTypes() { + static std::vector all_fixed_size_tensor_types = + {DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType()}; + + return all_fixed_size_tensor_types; +} + +const std::vector& DataTypeImpl::AllIEEEFloatTensorExceptHalfTypes() { + static std::vector all_IEEE_float_tensor_except_half_types = + {DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType()}; + + return all_IEEE_float_tensor_except_half_types; +} + +const std::vector& DataTypeImpl::AllIEEEFloatTensorTypes() { + static std::vector all_IEEE_float_tensor_types = + {DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType()}; + + return all_IEEE_float_tensor_types; +} + const std::vector& DataTypeImpl::AllFixedSizeTensorTypes() { static std::vector all_fixed_size_tensor_types = {DataTypeImpl::GetTensorType(), diff --git a/onnxruntime/core/framework/error_code.cc b/onnxruntime/core/framework/error_code.cc index 2cf11f4e1de8e..c727b7464f3ac 100644 --- a/onnxruntime/core/framework/error_code.cc +++ b/onnxruntime/core/framework/error_code.cc @@ -12,11 +12,12 @@ struct OrtStatus { char msg[1]; // a null-terminated string }; -ORT_API(OrtStatus*, OrtCreateStatus, OrtErrorCode code, _In_ const char* msg) { +//Even we say it may not return NULL, indeed it may. +ORT_EXPORT _Check_return_ _Ret_notnull_ OrtStatus* ORT_API_CALL OrtCreateStatus(OrtErrorCode code, _In_ const char* msg) NO_EXCEPTION { assert(!(code == 0 && msg != nullptr)); size_t clen = strlen(msg); OrtStatus* p = reinterpret_cast(::malloc(sizeof(OrtStatus) + clen)); - if (p == nullptr) return nullptr; // OOM + if (p == nullptr) return nullptr; // OOM. What we can do here? abort()? 
p->code = code; memcpy(p->msg, msg, clen); p->msg[clen] = '\0'; diff --git a/onnxruntime/core/framework/execution_frame.cc b/onnxruntime/core/framework/execution_frame.cc index c44bb3e0497a3..59a025a61711f 100644 --- a/onnxruntime/core/framework/execution_frame.cc +++ b/onnxruntime/core/framework/execution_frame.cc @@ -22,11 +22,15 @@ IExecutionFrame::IExecutionFrame(const std::vector& feed_mlvalue_idxs, cons const std::unordered_map& initializers, const std::vector& fetch_mlvalue_idxs, const std::vector& fetches, const OrtValueNameIdxMap& ort_value_idx_map, const NodeIndexInfo& node_index_info) - : node_index_info_{node_index_info}, fetch_mlvalue_idxs_{fetch_mlvalue_idxs} { + : node_index_info_{node_index_info}, + all_values_size_{static_cast(ort_value_idx_map.MaxIdx()) + 1}, + fetch_mlvalue_idxs_{fetch_mlvalue_idxs} { ORT_ENFORCE(feeds.size() == feed_mlvalue_idxs.size()); ORT_ENFORCE(fetches.empty() || fetches.size() == fetch_mlvalue_idxs_.size()); + ORT_ENFORCE(node_index_info_.GetMaxMLValueIdx() == ort_value_idx_map.MaxIdx(), + "node_index_info and ort_value_idx_map are out of sync and cannot be used"); - Init(feed_mlvalue_idxs, feeds, initializers, fetches, ort_value_idx_map); + Init(feed_mlvalue_idxs, feeds, initializers, fetches); } IExecutionFrame::~IExecutionFrame() = default; @@ -79,7 +83,7 @@ AllocatorPtr IExecutionFrame::GetAllocator(const OrtAllocatorInfo& info) const { Status IExecutionFrame::ReleaseMLValue(int ort_value_idx) { return ReleaseMLValueImpl(ort_value_idx); } Status IExecutionFrame::ReleaseMLValueImpl(int ort_value_idx) { - if (ort_value_idx == NodeIndexInfo::kInvalidEntry || static_cast(ort_value_idx) >= all_values_.size()) { + if (ort_value_idx == NodeIndexInfo::kInvalidEntry || static_cast(ort_value_idx) >= all_values_size_) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "invalid index ", ort_value_idx); } @@ -95,19 +99,16 @@ Status IExecutionFrame::ReleaseMLValueImpl(int ort_value_idx) { } int IExecutionFrame::GetNodeIdxToMLValueIdx(int index) const { + // the validity of index is checked by GetMLValueIndex int ort_value_idx = node_index_info_.GetMLValueIndex(index); - ORT_ENFORCE(ort_value_idx == NodeIndexInfo::kInvalidEntry || - (ort_value_idx >= 0 && static_cast(ort_value_idx) < all_values_.size())); - return ort_value_idx; } void IExecutionFrame::Init(const std::vector& feed_mlvalue_idxs, const std::vector& feeds, const std::unordered_map& initializers, - const std::vector& fetches, - const OrtValueNameIdxMap& ort_value_idx_map) { + const std::vector& fetches) { // 1. resize the all_value_ vector - all_values_.resize(ort_value_idx_map.MaxIdx() + 1); + all_values_.resize(all_values_size_); // 2. 
Handle non-empty output vector if (!fetches.empty()) { @@ -402,54 +403,54 @@ Status ExecutionFrame::AllocateAsPerAllocationPlan(OrtValue& ort_value, int ort_ const auto& alloc_info = per_alloc_plan.location; const auto* ml_type = per_alloc_plan.value_type; - if (ml_type == nullptr) + if (ml_type == nullptr) { return Status( ONNXRUNTIME, INVALID_ARGUMENT, "Tried to allocate without valid type information, ort_value index=" + std::to_string(ort_value_index)); - - if (ml_type->IsSparseTensorType()) { - return AllocateSparseTensor(ort_value, *ml_type, GetAllocator(alloc_info), - *shape, nnz, per_alloc_plan.create_fence_if_async, session_state_); - } - if (!ml_type->IsTensorType()) { - return AllocateTraditionalMLValue(ort_value, *static_cast(ml_type)); } - ORT_ENFORCE(shape, "Allocation of tensor types requires a shape."); + if (ml_type->IsTensorType()) { + ORT_ENFORCE(shape, "Allocation of tensor types requires a shape."); - // tensors - const auto* ml_data_type = static_cast(ml_type)->GetElementType(); + // tensors + const auto* ml_data_type = static_cast(ml_type)->GetElementType(); - AllocKind alloc_kind = per_alloc_plan.alloc_kind; - switch (alloc_kind) { - // Right now for kAllocate and kAllocateOutput we are using same approach. - // In the future we may want to have different way to handle it. - case AllocKind::kAllocateOutput: - case AllocKind::kAllocate: { - ORT_RETURN_IF_ERROR(AllocateMLValueTensorSelfOwnBuffer(ort_value, ort_value_index, ml_data_type, alloc_info, - *shape, per_alloc_plan.create_fence_if_async)); - break; - } - case AllocKind::kReuse: { - int reuse_mlvalue_index = per_alloc_plan.reused_buffer; - ORT_RETURN_IF_ERROR(AllocateMLValueTensorPreAllocateBuffer( - ort_value, reuse_mlvalue_index, ml_data_type, alloc_info, *shape, per_alloc_plan.create_fence_if_async)); - break; - } - case AllocKind::kShare: { - int reuse_mlvalue_index = per_alloc_plan.reused_buffer; - // copy at the OrtValue level so the shared_ptr for the data is shared between the two OrtValue instances - ort_value = GetMutableMLValue(reuse_mlvalue_index); - break; - } - default: { - std::ostringstream ostr; - ostr << "Invalid allocation kind: " << static_cast::type>(alloc_kind); - return Status(ONNXRUNTIME, FAIL, ostr.str()); + AllocKind alloc_kind = per_alloc_plan.alloc_kind; + switch (alloc_kind) { + // Right now for kAllocate and kAllocateOutput we are using same approach. + // In the future we may want to have different way to handle it. 
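[Editorial aside, not part of the patch] The reordered AllocateAsPerAllocationPlan above checks the tensor path first because it is the common case, then dispatches on the planned allocation kind: kAllocate/kAllocateOutput own a fresh buffer, kReuse binds to another value's pre-allocated buffer, and kShare copies the OrtValue so the underlying data is shared. The sketch below condenses that dispatch; ValuePlan and DescribeAllocation are simplified stand-ins, not ONNX Runtime types.

```cpp
#include <iostream>
#include <string>

enum class AllocKind { kAllocate, kAllocateOutput, kReuse, kShare };

struct ValuePlan {
  AllocKind alloc_kind;
  int reused_buffer = -1;  // meaningful only for kReuse / kShare
};

std::string DescribeAllocation(const ValuePlan& plan, int value_index) {
  switch (plan.alloc_kind) {
    // kAllocate and kAllocateOutput currently follow the same path.
    case AllocKind::kAllocate:
    case AllocKind::kAllocateOutput:
      return "value " + std::to_string(value_index) + ": allocate a buffer it owns";
    case AllocKind::kReuse:
      return "value " + std::to_string(value_index) + ": reuse the buffer of value " +
             std::to_string(plan.reused_buffer);
    case AllocKind::kShare:
      return "value " + std::to_string(value_index) + ": share the OrtValue of value " +
             std::to_string(plan.reused_buffer);
    default:
      return "invalid allocation kind";
  }
}

int main() {
  std::cout << DescribeAllocation({AllocKind::kAllocate}, 0) << "\n"
            << DescribeAllocation({AllocKind::kReuse, 0}, 3) << "\n";
}
```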
+ case AllocKind::kAllocateOutput: + case AllocKind::kAllocate: { + ORT_RETURN_IF_ERROR(AllocateMLValueTensorSelfOwnBuffer(ort_value, ort_value_index, ml_data_type, alloc_info, + *shape, per_alloc_plan.create_fence_if_async)); + break; + } + case AllocKind::kReuse: { + int reuse_mlvalue_index = per_alloc_plan.reused_buffer; + ORT_RETURN_IF_ERROR(AllocateMLValueTensorPreAllocateBuffer( + ort_value, reuse_mlvalue_index, ml_data_type, alloc_info, *shape, per_alloc_plan.create_fence_if_async)); + break; + } + case AllocKind::kShare: { + int reuse_mlvalue_index = per_alloc_plan.reused_buffer; + // copy at the OrtValue level so the shared_ptr for the data is shared between the two OrtValue instances + ort_value = GetMutableMLValue(reuse_mlvalue_index); + break; + } + default: { + std::ostringstream ostr; + ostr << "Invalid allocation kind: " << static_cast::type>(alloc_kind); + return Status(ONNXRUNTIME, FAIL, ostr.str()); + } } - } - return Status::OK(); + return Status::OK(); + } else if (ml_type->IsSparseTensorType()) { + return AllocateSparseTensor(ort_value, *ml_type, GetAllocator(alloc_info), + *shape, nnz, per_alloc_plan.create_fence_if_async, session_state_); + } else { + return AllocateTraditionalMLValue(ort_value, *static_cast(ml_type)); + } } AllocatorPtr ExecutionFrame::GetAllocatorImpl(const OrtAllocatorInfo& info) const { diff --git a/onnxruntime/core/framework/execution_frame.h b/onnxruntime/core/framework/execution_frame.h index c99979edb7eba..06d042de3bd20 100644 --- a/onnxruntime/core/framework/execution_frame.h +++ b/onnxruntime/core/framework/execution_frame.h @@ -74,10 +74,10 @@ class IExecutionFrame { void Init(const std::vector& feed_mlvalue_idxs, const std::vector& feeds, const std::unordered_map& initializers, - const std::vector& fetches, const OrtValueNameIdxMap& ort_value_idx_map); + const std::vector& fetches); const OrtValue& GetMLValue(int ort_value_index) const { - ORT_ENFORCE(ort_value_index >= 0 && static_cast(ort_value_index) < all_values_.size()); + ORT_ENFORCE(ort_value_index >= 0 && static_cast(ort_value_index) < all_values_size_); return all_values_[ort_value_index]; } @@ -91,6 +91,9 @@ class IExecutionFrame { // Input and Output values are passed in by executors std::vector all_values_; + // perf optimization to avoid calling all_values_.size() repeatedly as the size is fixed once constructed + const size_t all_values_size_; + const std::vector fetch_mlvalue_idxs_; }; diff --git a/onnxruntime/core/framework/feeds_fetches_manager.h b/onnxruntime/core/framework/feeds_fetches_manager.h index 000eaa504176f..d646c82ab23d4 100644 --- a/onnxruntime/core/framework/feeds_fetches_manager.h +++ b/onnxruntime/core/framework/feeds_fetches_manager.h @@ -48,9 +48,8 @@ struct FeedsFetchesInfo { class FeedsFetchesManager { public: struct MLValueCopyInfo { - int allocation_device_id = 0; + OrtDevice target_device; const IExecutionProvider* allocation_provider = nullptr; - const IExecutionProvider* copy_provider = nullptr; }; static Status Create(const std::vector& feed_names, const std::vector& output_names, diff --git a/onnxruntime/core/framework/graph_partitioner.cc b/onnxruntime/core/framework/graph_partitioner.cc index 5b0cba6c3b0d8..fe53971656932 100644 --- a/onnxruntime/core/framework/graph_partitioner.cc +++ b/onnxruntime/core/framework/graph_partitioner.cc @@ -2,7 +2,6 @@ // Licensed under the MIT License. 
#include "core/framework/graph_partitioner.h" - #include "core/framework/kernel_registry_manager.h" #include "core/graph/function.h" #include "core/graph/graph_viewer.h" @@ -176,10 +175,6 @@ Status GraphPartitioner::Partition(Graph& graph, bool export_dll, FuncManager& f //prepare the func kernel KernelDefBuilder builder; BuildFusedKernelDef(builder, *node); - if (node->GetExecutionProviderType() == onnxruntime::kTensorrtExecutionProvider || node->GetExecutionProviderType() == onnxruntime::kNGraphExecutionProvider || node->GetExecutionProviderType() == onnxruntime::kNnapiExecutionProvider) { - builder.SetDefaultInputsMemoryType(OrtMemTypeCPUInput); - builder.SetDefaultOutputMemoryType(OrtMemTypeCPUOutput); - } ORT_RETURN_IF_ERROR(fused_kernel_registry->Register( builder, static_cast([](const OpKernelInfo& info) -> OpKernel* { return new FunctionKernel(info); }))); } diff --git a/onnxruntime/core/framework/kernel_registry_manager.cc b/onnxruntime/core/framework/kernel_registry_manager.cc index 5fe803b368022..203bc7c21e45f 100644 --- a/onnxruntime/core/framework/kernel_registry_manager.cc +++ b/onnxruntime/core/framework/kernel_registry_manager.cc @@ -14,11 +14,21 @@ Status KernelRegistryManager::CreateKernel(const onnxruntime::Node& node, const IExecutionProvider& execution_provider, const SessionState& session_state, /*out*/ std::unique_ptr& op_kernel) const { + auto create_error_message = [&node](const std::string& error) { + std::ostringstream errormsg; + errormsg << error << node.OpType(); + if (node.Op() != nullptr) errormsg << "(" << node.Op()->since_version() << ")"; + if (!node.Name().empty()) errormsg << " (node " << node.Name() << ")"; + return errormsg.str(); + }; + const std::string& ptype = node.GetExecutionProviderType(); if (ptype.empty()) { return Status(ONNXRUNTIME, FAIL, - "The node is not placed on any Execution Provider, therefore, can't find a suitable kernel for it"); + create_error_message("The node is not placed on any Execution Provider, " + "therefore, can't find a suitable kernel for ")); } + Status status; { for (auto& registry : custom_kernel_registries_) { @@ -41,11 +51,7 @@ Status KernelRegistryManager::CreateKernel(const onnxruntime::Node& node, } } - std::ostringstream errormsg; - errormsg << "Failed to find kernel for " << node.OpType(); - if (node.Op() != nullptr) errormsg << "(" << node.Op()->since_version() << ")"; - if (!node.Name().empty()) errormsg << " (node " << node.Name() << ")"; - return Status(ONNXRUNTIME, FAIL, errormsg.str()); + return Status(ONNXRUNTIME, FAIL, create_error_message("Failed to find kernel for ")); } Status KernelRegistryManager::RegisterKernels(const ExecutionProviders& execution_providers) { diff --git a/onnxruntime/core/framework/mem_pattern.h b/onnxruntime/core/framework/mem_pattern.h index 57d9e99360b13..2aa1e3cad32eb 100644 --- a/onnxruntime/core/framework/mem_pattern.h +++ b/onnxruntime/core/framework/mem_pattern.h @@ -20,11 +20,11 @@ class MemoryPattern { public: MemoryPattern() = default; - MemoryPattern(MemoryPattern&& rhs) + MemoryPattern(MemoryPattern&& rhs) noexcept : patterns_{std::move(rhs.patterns_)}, peak_size_{std::move(rhs.peak_size_)} {} - MemoryPattern& operator=(MemoryPattern&& rhs) { + MemoryPattern& operator=(MemoryPattern&& rhs) noexcept { patterns_ = std::move(rhs.patterns_); peak_size_ = std::move(rhs.peak_size_); return *this; diff --git a/onnxruntime/core/framework/node_index_info.cc b/onnxruntime/core/framework/node_index_info.cc index 7931825e7fd7c..d77a72cabc909 100644 --- 
a/onnxruntime/core/framework/node_index_info.cc +++ b/onnxruntime/core/framework/node_index_info.cc @@ -69,6 +69,10 @@ void NodeIndexInfo::Init(const TValidNodes& nodes, NodeIndex max_node_index, // init all to kInvalidEntry node_offsets_.resize(GetNodeOffsetsIndex(max_node_index), kInvalidEntry); node_values_.resize(total_def_count, kInvalidEntry); + + node_offsets_size_ = node_offsets_.size(); + node_values_size_ = node_values_.size(); + int cur_idx = 0; for (auto& node : nodes) { diff --git a/onnxruntime/core/framework/node_index_info.h b/onnxruntime/core/framework/node_index_info.h index afd74a1874900..19b4a202f578f 100644 --- a/onnxruntime/core/framework/node_index_info.h +++ b/onnxruntime/core/framework/node_index_info.h @@ -31,14 +31,14 @@ class NodeIndexInfo final { // Returns kInvalidEntry if the Node with the given node_index did not exist when the NodeIndexInfo was created. int GetNodeOffset(NodeIndex node_index) const { auto node_offsets_index = GetNodeOffsetsIndex(node_index); - ORT_ENFORCE(node_offsets_index < node_offsets_.size()); + ORT_ENFORCE(node_offsets_index < node_offsets_size_); return node_offsets_[node_offsets_index]; } // Get the ort_value index value. // Returns kInvalidEntry for optional inputs/outputs that do not exist in this graph. int GetMLValueIndex(int offset) const { - ORT_ENFORCE(offset >= 0 && static_cast(offset) < node_values_.size()); + ORT_ENFORCE(offset >= 0 && static_cast(offset) < node_values_size_); return node_values_[offset]; } @@ -63,5 +63,9 @@ class NodeIndexInfo final { std::vector node_offsets_; const int max_mlvalue_idx_; + + // perf optimization to avoid calls to size() on node_values_ and node_offsets_ as they don't change + size_t node_values_size_; + size_t node_offsets_size_; }; } // namespace onnxruntime diff --git a/onnxruntime/core/framework/op_kernel_context_internal.h b/onnxruntime/core/framework/op_kernel_context_internal.h index 02515ba39a160..b837356504d36 100644 --- a/onnxruntime/core/framework/op_kernel_context_internal.h +++ b/onnxruntime/core/framework/op_kernel_context_internal.h @@ -5,6 +5,7 @@ #include "core/framework/op_kernel.h" #include "core/framework/session_state.h" +#include "core/session/onnxruntime_c_api.h" // onnxruntime internal OpKernelContext derived class to provide additional // APIs that aren't desirable to add to the public OpKernelContext API @@ -57,7 +58,8 @@ class OpKernelContextInternal : public OpKernelContext { const bool& GetTerminateFlag() const noexcept { return terminate_flag_; } - const onnxruntime::concurrency::ThreadPool* GetOperatorThreadPool() const { return session_state_.GetThreadPool(); } + _Ret_maybenull_ const onnxruntime::concurrency::ThreadPool* GetOperatorThreadPool() const { return session_state_.GetThreadPool(); } + _Ret_maybenull_ onnxruntime::concurrency::ThreadPool* GetOperatorThreadPool() { return session_state_.GetThreadPool(); } private: const SessionState& session_state_; diff --git a/onnxruntime/core/framework/parallel_executor.cc b/onnxruntime/core/framework/parallel_executor.cc index 72ee80cd421ee..ff33f93eab6c4 100644 --- a/onnxruntime/core/framework/parallel_executor.cc +++ b/onnxruntime/core/framework/parallel_executor.cc @@ -122,6 +122,7 @@ Status ParallelExecutor::RunNodeAsync(size_t p_node_index, TimePoint sync_time_begin; TimePoint kernel_begin_time; const bool f_profiler_enabled = session_state.Profiler().IsEnabled(); + const SequentialExecutionPlan& exec_plan = *session_state.GetExecutionPlan(); // Avoid context switching if possible. 
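[Editorial aside, not part of the patch] The executor hunks that follow (parallel_executor.cc here, sequential_executor.cc further down) wrap the per-input/per-output fence loops in a single NodeHasFence check, using the node_has_fence flags that ComputeFenceCheck filled in at plan-creation time in allocation_planner.cc. The condensed sketch below shows that fast path with simplified stand-in types; it is not the actual executor code.

```cpp
#include <cstddef>
#include <vector>

struct Fence {
  void BeforeUsingAsInput() {}
  void AfterUsedAsInput() {}
};

struct NodeContext {
  std::vector<Fence*> input_fences;  // nullptr when a value carries no fence
};

struct ExecutionPlanSketch {
  std::vector<bool> node_has_fence;  // filled once, at plan-creation time
  bool NodeHasFence(std::size_t node_index) const { return node_has_fence[node_index]; }
};

void RunNodeSketch(const ExecutionPlanSketch& plan, std::size_t node_index, NodeContext& ctx) {
  // Fast path: most nodes never touch a fence, so one flag check replaces
  // iterating every input and output both before and after compute.
  if (plan.NodeHasFence(node_index)) {
    for (Fence* f : ctx.input_fences)
      if (f) f->BeforeUsingAsInput();
  }

  // ... kernel compute would happen here ...

  if (plan.NodeHasFence(node_index)) {
    for (Fence* f : ctx.input_fences)
      if (f) f->AfterUsedAsInput();
  }
}

int main() {
  ExecutionPlanSketch plan{{false, true}};
  Fence fence;
  NodeContext ctx{{&fence, nullptr}};
  RunNodeSketch(plan, 0, ctx);  // node 0 has no fence: both loops are skipped
  RunNodeSketch(plan, 1, ctx);  // node 1 syncs its fenced input before and after compute
}
```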
while (keep_running) { @@ -149,33 +150,34 @@ Status ParallelExecutor::RunNodeAsync(size_t p_node_index, } // sync before compute int queue_id = p_op_kernel->KernelDef().ExecQueueId(); - - for (int input_index = 0; input_index < op_kernel_context.InputCount(); ++input_index) { - Fence_t fence = op_kernel_context.InputFence(input_index); - if (fence) { - auto execution_provider_type = p_op_kernel->Node().GetExecutionProviderType(); - if (OrtMemTypeCPUInput == p_op_kernel->KernelDef().InputMemoryType(input_index)) { - execution_provider_type = kCpuExecutionProvider; + if (exec_plan.NodeHasFence(node_index)) { + for (int input_index = 0; input_index < op_kernel_context.InputCount(); ++input_index) { + Fence_t fence = op_kernel_context.InputFence(input_index); + if (fence) { + auto execution_provider_type = p_op_kernel->Node().GetExecutionProviderType(); + if (OrtMemTypeCPUInput == p_op_kernel->KernelDef().InputMemoryType(input_index)) { + execution_provider_type = kCpuExecutionProvider; + } + fence->BeforeUsingAsInput(execution_provider_type, queue_id); } - fence->BeforeUsingAsInput(execution_provider_type, queue_id); } - } - for (int input_index = 0; input_index < op_kernel_context.ImplicitInputCount(); ++input_index) { - Fence_t fence = op_kernel_context.ImplicitInputFence(input_index); - if (fence) { - auto execution_provider_type = p_op_kernel->Node().GetExecutionProviderType(); - if (OrtMemTypeCPUInput == p_op_kernel->KernelDef().InputMemoryType(input_index)) { - execution_provider_type = kCpuExecutionProvider; + for (int input_index = 0; input_index < op_kernel_context.ImplicitInputCount(); ++input_index) { + Fence_t fence = op_kernel_context.ImplicitInputFence(input_index); + if (fence) { + auto execution_provider_type = p_op_kernel->Node().GetExecutionProviderType(); + if (OrtMemTypeCPUInput == p_op_kernel->KernelDef().InputMemoryType(input_index)) { + execution_provider_type = kCpuExecutionProvider; + } + fence->BeforeUsingAsInput(execution_provider_type, queue_id); } - fence->BeforeUsingAsInput(execution_provider_type, queue_id); } - } - for (int output_index = 0; output_index < op_kernel_context.OutputCount(); ++output_index) { - Fence_t fence = op_kernel_context.OutputFence(output_index); - if (fence) { - fence->BeforeUsingAsOutput(p_op_kernel->Node().GetExecutionProviderType(), queue_id); + for (int output_index = 0; output_index < op_kernel_context.OutputCount(); ++output_index) { + Fence_t fence = op_kernel_context.OutputFence(output_index); + if (fence) { + fence->BeforeUsingAsOutput(p_op_kernel->Node().GetExecutionProviderType(), queue_id); + } } } @@ -209,32 +211,36 @@ Status ParallelExecutor::RunNodeAsync(size_t p_node_index, sync_time_begin = session_state.Profiler().StartTime(); } // sync after compute for outputs - for (int input_index = 0; input_index < op_kernel_context.InputCount(); ++input_index) { - Fence_t fence = op_kernel_context.InputFence(input_index); - if (fence) { - fence->AfterUsedAsInput(queue_id); + if (exec_plan.NodeHasFence(node_index)) { + for (int input_index = 0; input_index < op_kernel_context.InputCount(); ++input_index) { + Fence_t fence = op_kernel_context.InputFence(input_index); + if (fence) { + fence->AfterUsedAsInput(queue_id); + } } - } - for (int input_index = 0; input_index < op_kernel_context.ImplicitInputCount(); ++input_index) { - Fence_t fence = op_kernel_context.ImplicitInputFence(input_index); - if (fence) { - fence->AfterUsedAsInput(queue_id); + for (int input_index = 0; input_index < op_kernel_context.ImplicitInputCount(); 
++input_index) { + Fence_t fence = op_kernel_context.ImplicitInputFence(input_index); + if (fence) { + fence->AfterUsedAsInput(queue_id); + } } - } - for (int output_index = 0; output_index < op_kernel_context.OutputCount(); ++output_index) { - Fence_t fence = op_kernel_context.OutputFence(output_index); - if (fence) { - fence->AfterUsedAsOutput(queue_id); + for (int output_index = 0; output_index < op_kernel_context.OutputCount(); ++output_index) { + Fence_t fence = op_kernel_context.OutputFence(output_index); + if (fence) { + fence->AfterUsedAsOutput(queue_id); + } } } + if (f_profiler_enabled) { session_state.Profiler().EndTimeAndRecordEvent(profiling::NODE_EVENT, p_op_kernel->Node().Name() + "_fence_after", sync_time_begin, {{"op_name", p_op_kernel->KernelDef().OpName()}}); } + //std::cout << "Run async node finish: " << p_node_index << std::endl; keep_running = false; diff --git a/onnxruntime/core/framework/parallel_executor.h b/onnxruntime/core/framework/parallel_executor.h index 5f34309937bac..74d3fbce3d8d4 100644 --- a/onnxruntime/core/framework/parallel_executor.h +++ b/onnxruntime/core/framework/parallel_executor.h @@ -21,7 +21,6 @@ class ExecutionFrame; class ParallelExecutor : public IExecutor { public: - ParallelExecutor(const bool& terminate_flag = false) : terminate_flag_{terminate_flag} {} ParallelExecutor(const SessionState& session_state, const bool& terminate_flag = false); common::Status Execute(const SessionState& session_state, const std::vector& feed_mlvalue_idxs, diff --git a/onnxruntime/core/framework/run_options.cc b/onnxruntime/core/framework/run_options.cc index 079be56fc5ae4..640c610841774 100644 --- a/onnxruntime/core/framework/run_options.cc +++ b/onnxruntime/core/framework/run_options.cc @@ -17,6 +17,11 @@ ORT_API_STATUS_IMPL(OrtRunOptionsSetRunLogVerbosityLevel, _In_ OrtRunOptions* op return nullptr; } +ORT_API_STATUS_IMPL(OrtRunOptionsSetRunLogSeverityLevel, _In_ OrtRunOptions* options, int value) { + options->run_log_severity_level = value; + return nullptr; +} + ORT_API_STATUS_IMPL(OrtRunOptionsSetRunTag, _In_ OrtRunOptions* options, _In_ const char* run_tag) { if (run_tag) options->run_tag = run_tag; @@ -28,6 +33,11 @@ ORT_API_STATUS_IMPL(OrtRunOptionsGetRunLogVerbosityLevel, _In_ const OrtRunOptio return nullptr; } +ORT_API_STATUS_IMPL(OrtRunOptionsGetRunLogSeverityLevel, _In_ const OrtRunOptions* options, int* out) { + *out = options->run_log_severity_level; + return nullptr; +} + ORT_API_STATUS_IMPL(OrtRunOptionsGetRunTag, _In_ const OrtRunOptions* options, const char** out) { *out = options->run_tag.c_str(); return nullptr; diff --git a/onnxruntime/core/framework/sequential_execution_plan.h b/onnxruntime/core/framework/sequential_execution_plan.h index 24ed345965cbc..5c6827966dd41 100644 --- a/onnxruntime/core/framework/sequential_execution_plan.h +++ b/onnxruntime/core/framework/sequential_execution_plan.h @@ -66,6 +66,9 @@ struct SequentialExecutionPlan : public ExecutionPlanBase { // Execution_plan: represents the nodes in the sequential order to be executed std::vector execution_plan; + // Records whether a given node has fence on its input or output, key is node index. + std::vector node_has_fence; + // to_be_freed: vector elements represent indices of ml-values to be freed (as described above) std::vector to_be_freed; @@ -84,6 +87,12 @@ struct SequentialExecutionPlan : public ExecutionPlanBase { } return locations; } + + // Whether a given node needs fence check or not. 
+ bool NodeHasFence(onnxruntime::NodeIndex node_index) const { + return node_has_fence[node_index]; + } + }; // Output details of an execution plan: diff --git a/onnxruntime/core/framework/sequential_executor.cc b/onnxruntime/core/framework/sequential_executor.cc index bd45bbfdc0b01..0f08e8613cc1a 100644 --- a/onnxruntime/core/framework/sequential_executor.cc +++ b/onnxruntime/core/framework/sequential_executor.cc @@ -71,32 +71,34 @@ Status SequentialExecutor::Execute(const SessionState& session_state, const std: // sync before compute int queue_id = p_op_kernel->KernelDef().ExecQueueId(); - for (int input_index = 0; input_index < op_kernel_context.InputCount(); ++input_index) { - Fence_t fence = op_kernel_context.InputFence(input_index); - if (fence) { - auto execution_provider_type = p_op_kernel->Node().GetExecutionProviderType(); - if (OrtMemTypeCPUInput == p_op_kernel->KernelDef().InputMemoryType(input_index)) { - execution_provider_type = kCpuExecutionProvider; + if (seq_exec_plan.NodeHasFence(node_index)) { + for (int input_index = 0; input_index < op_kernel_context.InputCount(); ++input_index) { + Fence_t fence = op_kernel_context.InputFence(input_index); + if (fence) { + auto execution_provider_type = p_op_kernel->Node().GetExecutionProviderType(); + if (OrtMemTypeCPUInput == p_op_kernel->KernelDef().InputMemoryType(input_index)) { + execution_provider_type = kCpuExecutionProvider; + } + fence->BeforeUsingAsInput(execution_provider_type, queue_id); } - fence->BeforeUsingAsInput(execution_provider_type, queue_id); } - } - for (int input_index = 0; input_index < op_kernel_context.ImplicitInputCount(); ++input_index) { - Fence_t fence = op_kernel_context.ImplicitInputFence(input_index); - if (fence) { - auto execution_provider_type = p_op_kernel->Node().GetExecutionProviderType(); - if (OrtMemTypeCPUInput == p_op_kernel->KernelDef().InputMemoryType(input_index)) { - execution_provider_type = kCpuExecutionProvider; + for (int input_index = 0; input_index < op_kernel_context.ImplicitInputCount(); ++input_index) { + Fence_t fence = op_kernel_context.ImplicitInputFence(input_index); + if (fence) { + auto execution_provider_type = p_op_kernel->Node().GetExecutionProviderType(); + if (OrtMemTypeCPUInput == p_op_kernel->KernelDef().InputMemoryType(input_index)) { + execution_provider_type = kCpuExecutionProvider; + } + fence->BeforeUsingAsInput(execution_provider_type, queue_id); } - fence->BeforeUsingAsInput(execution_provider_type, queue_id); } - } - for (int output_index = 0; output_index < op_kernel_context.OutputCount(); ++output_index) { - Fence_t fence = op_kernel_context.OutputFence(output_index); - if (fence) { - fence->BeforeUsingAsOutput(p_op_kernel->Node().GetExecutionProviderType(), queue_id); + for (int output_index = 0; output_index < op_kernel_context.OutputCount(); ++output_index) { + Fence_t fence = op_kernel_context.OutputFence(output_index); + if (fence) { + fence->BeforeUsingAsOutput(p_op_kernel->Node().GetExecutionProviderType(), queue_id); + } } } @@ -138,24 +140,26 @@ Status SequentialExecutor::Execute(const SessionState& session_state, const std: } // sync after compute for outputs - for (int input_index = 0; input_index < op_kernel_context.InputCount(); ++input_index) { - Fence_t fence = op_kernel_context.InputFence(input_index); - if (fence) { - fence->AfterUsedAsInput(queue_id); + if (seq_exec_plan.NodeHasFence(node_index)) { + for (int input_index = 0; input_index < op_kernel_context.InputCount(); ++input_index) { + Fence_t fence = 
op_kernel_context.InputFence(input_index); + if (fence) { + fence->AfterUsedAsInput(queue_id); + } } - } - for (int input_index = 0; input_index < op_kernel_context.ImplicitInputCount(); ++input_index) { - Fence_t fence = op_kernel_context.ImplicitInputFence(input_index); - if (fence) { - fence->AfterUsedAsInput(queue_id); + for (int input_index = 0; input_index < op_kernel_context.ImplicitInputCount(); ++input_index) { + Fence_t fence = op_kernel_context.ImplicitInputFence(input_index); + if (fence) { + fence->AfterUsedAsInput(queue_id); + } } - } - for (int output_index = 0; output_index < op_kernel_context.OutputCount(); ++output_index) { - Fence_t fence = op_kernel_context.OutputFence(output_index); - if (fence) { - fence->AfterUsedAsOutput(queue_id); + for (int output_index = 0; output_index < op_kernel_context.OutputCount(); ++output_index) { + Fence_t fence = op_kernel_context.OutputFence(output_index); + if (fence) { + fence->AfterUsedAsOutput(queue_id); + } } } diff --git a/onnxruntime/core/framework/session_state.cc b/onnxruntime/core/framework/session_state.cc index a6fe46be955ed..fbf0f50d37253 100644 --- a/onnxruntime/core/framework/session_state.cc +++ b/onnxruntime/core/framework/session_state.cc @@ -11,23 +11,96 @@ #include "core/framework/utils.h" using namespace ::onnxruntime::common; -namespace onnxruntime { -void SessionState::SetGraphViewer(std::unique_ptr graph_viewer) { - ORT_ENFORCE(nullptr != graph_viewer); - graph_viewer_ = std::move(graph_viewer); -} +namespace onnxruntime { const GraphViewer* SessionState::GetGraphViewer() const { return graph_viewer_.get(); } +Status SessionState::SetGraph(const Graph& graph) { + graph_viewer_ = std::make_unique(graph); + auto& logger = Logger(); + // use graph_viewer_ to initialize ort_value_name_idx_map_ + LOGS(logger, INFO) << "SaveMLValueNameIndexMapping"; + int idx = 0; + + // we keep all graph inputs (including initializers), even if they are unused, so make sure they all have an entry + for (const auto* input_def : graph_viewer_->GetInputsIncludingInitializers()) { + idx = ort_value_name_idx_map_.Add(input_def->Name()); + VLOGS(logger, 1) << "Added graph_viewer_ input with name: " << input_def->Name() + << " to OrtValueIndex with index: " << idx; + } + + for (auto& node : graph_viewer_->Nodes()) { + // build the OrtValue->index map + for (const auto* input_def : node.InputDefs()) { + if (input_def->Exists()) { + idx = ort_value_name_idx_map_.Add(input_def->Name()); + VLOGS(logger, 1) << "Added input argument with name: " << input_def->Name() + << " to OrtValueIndex with index: " << idx; + } + } + + for (const auto* input_def : node.ImplicitInputDefs()) { + if (input_def->Exists()) { + idx = ort_value_name_idx_map_.Add(input_def->Name()); + VLOGS(logger, 1) << "Added implicit input argument with name: " << input_def->Name() + << " to OrtValueIndex with index: " << idx; + } + } + + for (const auto* output_def : node.OutputDefs()) { + if (output_def->Exists()) { + ort_value_name_idx_map_.Add(output_def->Name()); + VLOGS(logger, 1) << "Added output argument with name: " << output_def->Name() + << " to OrtValueIndex with index: " << idx; + } + } + } + + // allocate OrtValue for graph outputs when coming from initializers + for (const auto& output : graph_viewer_->GetOutputs()) { + if (output->Exists()) { + idx = ort_value_name_idx_map_.Add(output->Name()); + VLOGS(logger, 1) << "Added graph output with name: " << output->Name() << " to OrtValueIndex with index: " << idx; + } + } -const OpKernel* 
SessionState::GetKernel(NodeIndex node_id) const { - auto kernel = session_kernels_.find(node_id); - return (kernel != session_kernels_.cend()) ? kernel->second.get() : nullptr; + LOGS(logger, INFO) << "Done saving OrtValue mappings."; + return Status::OK(); } -void SessionState::AddKernel(onnxruntime::NodeIndex node_id, std::unique_ptr p_kernel) { - // assumes vector is already resize()'ed to the number of nodes in the graph - session_kernels_[node_id] = std::move(p_kernel); +Status SessionState::CreateKernels(const KernelRegistryManager& custom_registry_manager) { + const GraphNodes& nodes = graph_viewer_->Nodes(); + if (!nodes.empty()) { + size_t max_nodeid = 0; + for (auto& node : graph_viewer_->Nodes()) { + max_nodeid = std::max(max_nodeid, node.Index()); + } + session_kernels_.clear(); + session_kernels_.resize(max_nodeid + 1, nullptr); + for (auto& node : graph_viewer_->Nodes()) { + // construct and save the kernels + std::unique_ptr op_kernel; + onnxruntime::ProviderType exec_provider_name = node.GetExecutionProviderType(); + + const IExecutionProvider* exec_provider = nullptr; + if (exec_provider_name.empty() || (exec_provider = execution_providers_.Get(exec_provider_name)) == nullptr) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Could not create kernel for node: ", node.Name(), + " as there's no execution provider allocated."); + } + + common::Status status = custom_registry_manager.CreateKernel(node, *exec_provider, *this, op_kernel); + if (!status.IsOK()) { + return common::Status( + status.Category(), status.Code(), + MakeString("Kernel creation failed for node: ", node.Name(), " with error: ", status.ErrorMessage())); + } + assert(session_kernels_[node.Index()] == nullptr); + // assumes vector is already resize()'ed to the number of nodes in the graph + session_kernels_[node.Index()] = op_kernel.release(); + } + } + node_index_info_ = std::make_unique(*graph_viewer_, ort_value_name_idx_map_); + return Status::OK(); } void SessionState::SetExecutionPlan(std::unique_ptr p_seq_exec_plan) { @@ -38,7 +111,6 @@ const SequentialExecutionPlan* SessionState::GetExecutionPlan() const { return p Status SessionState::AddInitializedTensor(int ort_value_index, const OrtValue& ort_value, const OrtCallback* d, bool constant) { - ORT_ENFORCE(ort_value_index >= 0 && ort_value_index <= ort_value_name_idx_map_.MaxIdx()); auto p = initialized_tensors_.insert({ort_value_index, ort_value}); if (!p.second) return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "duplicated ort_value index:", ort_value_index, @@ -55,9 +127,7 @@ Status SessionState::AddInitializedTensor(int ort_value_index, const OrtValue& o return Status::OK(); } -const std::unordered_map& SessionState::GetInitializedTensors() const { - return initialized_tensors_; -} +const std::unordered_map& SessionState::GetInitializedTensors() const { return initialized_tensors_; } const std::unordered_map& SessionState::GetConstantInitializedTensors() const { return constant_initialized_tensors_; @@ -86,7 +156,8 @@ static int64_t CalculateMemoryPatternsKey(const std::vector>& input_shapes) const { +const MemoryPatternGroup* SessionState::GetMemoryPatternGroup( + const std::vector>& input_shapes) const { int64_t key = CalculateMemoryPatternsKey(input_shapes); std::lock_guard lock(mem_patterns_lock_); @@ -96,8 +167,9 @@ const MemoryPatternGroup* SessionState::GetMemoryPatternGroup(const std::vector< return it->second.get(); } -Status SessionState::UpdateMemoryPatternGroupCache(const std::vector>& input_shapes, - std::unique_ptr mem_patterns) 
const { +Status SessionState::UpdateMemoryPatternGroupCache( + const std::vector>& input_shapes, + std::unique_ptr mem_patterns) const { int64_t key = CalculateMemoryPatternsKey(input_shapes); std::lock_guard lock(mem_patterns_lock_); @@ -109,9 +181,7 @@ Status SessionState::UpdateMemoryPatternGroupCache(const std::vectorName(), " (", current_provider, - ") and node ", node_info.p_node->Name(), " (", new_provider, ")."); + return ORT_MAKE_STATUS( + ONNXRUNTIME, NOT_IMPLEMENTED, + "Using an input in multiple nodes on different devices is not supported currently. Input:", input_name, + " is used by node ", existing_entry.p_node->Name(), " (", current_device->ToString(), ") and node ", + node_info.p_node->Name(), " (", new_device->ToString(), ")."); } } } @@ -178,16 +249,15 @@ const SessionState::NameNodeInfoMapType& SessionState::GetOutputNodeInfoMap() co return output_names_to_nodeinfo_mapping_; } -void SessionState::AddSubgraphSessionState(onnxruntime::NodeIndex index, - const std::string& attribute_name, +void SessionState::AddSubgraphSessionState(onnxruntime::NodeIndex index, const std::string& attribute_name, std::unique_ptr session_state) { auto entry = subgraph_session_states_.find(index); // make sure this is new. internal logic error if it is not so using ORT_ENFORCE. if (entry != subgraph_session_states_.cend()) { const auto& existing_entries = entry->second; - ORT_ENFORCE(existing_entries.find(attribute_name) == existing_entries.cend(), - "Entry exists in node ", index, " for attribute ", attribute_name); + ORT_ENFORCE(existing_entries.find(attribute_name) == existing_entries.cend(), "Entry exists in node ", index, + " for attribute ", attribute_name); } subgraph_session_states_[index].insert(std::make_pair(attribute_name, std::move(session_state))); @@ -215,19 +285,8 @@ const SessionState* SessionState::GetSubgraphSessionState(onnxruntime::NodeIndex return const_cast(this)->GetMutableSubgraphSessionState(index, attribute_name); } -void SessionState::CalculateNodeIndexInfo() { - ORT_ENFORCE(graph_viewer_); - node_index_info_ = std::make_unique(*graph_viewer_, ort_value_name_idx_map_); - - for (auto& node_to_map_pair : subgraph_session_states_) { - for (auto& attr_name_to_subgraph : node_to_map_pair.second) { - attr_name_to_subgraph.second->CalculateNodeIndexInfo(); - } - } -} - const NodeIndexInfo& SessionState::GetNodeIndexInfo() const { - ORT_ENFORCE(node_index_info_, "CalculateNodeIndexInfo must be called prior to GetExecutionInfo."); + ORT_ENFORCE(node_index_info_, "SetGraphAndCreateKernels must be called prior to GetExecutionInfo."); return *node_index_info_; } } // namespace onnxruntime diff --git a/onnxruntime/core/framework/session_state.h b/onnxruntime/core/framework/session_state.h index dfec27108257a..b9e4c08900ddc 100644 --- a/onnxruntime/core/framework/session_state.h +++ b/onnxruntime/core/framework/session_state.h @@ -20,7 +20,7 @@ #include "core/framework/kernel_registry_manager.h" #include "core/framework/mem_pattern.h" #include "core/framework/ml_value.h" -#include "core/common/callback.h" +#include "core/framework/callback.h" #include "core/framework/ort_value_name_idx_map.h" #include "core/framework/node_index_info.h" #include "core/graph/graph_viewer.h" @@ -40,33 +40,41 @@ struct MemoryPatternGroup; * SessionState should be modified by the inference session class only. * It is supposed to be passed by const-ref only to all the executors. * This class owns all the initializers. + * Brief usage: + * SessionState s(...); + * for(...) 
s.AddInitializedTensor(...); + * s.SetGraphAndCreateKernels(...); + * Then you can use: + * s.GetKernel(...); */ class SessionState { public: - SessionState(const ExecutionProviders& execution_providers, bool enable_mem_pattern) - : execution_providers_{execution_providers}, enable_mem_pattern_(enable_mem_pattern) {} + SessionState(const ExecutionProviders& execution_providers, bool enable_mem_pattern, + concurrency::ThreadPool* thread_pool) + : execution_providers_{execution_providers}, enable_mem_pattern_(enable_mem_pattern), thread_pool_(thread_pool) {} ~SessionState() { + for (auto* p : session_kernels_) { + delete p; + } for (auto& kvp : deleter_for_initialized_tensors_) { kvp.second.f(kvp.second.param); } } // Graph viewer. - void SetGraphViewer(std::unique_ptr graph_viewer); const GraphViewer* GetGraphViewer() const; // kernels // Get kernel for specified node. // It should called right before graph execution only. - const OpKernel* GetKernel(NodeIndex node_id) const; - - void AddKernel(NodeIndex node_id, std::unique_ptr p_kernel); + const OpKernel* GetKernel(size_t node_id) const { + return (node_id < session_kernels_.size()) ? session_kernels_[node_id] : nullptr; + } const ExecutionProviders& GetExecutionProviders() const noexcept { return execution_providers_; } const OrtValueNameIdxMap& GetOrtValueNameIdxMap() const noexcept { return ort_value_name_idx_map_; } - OrtValueNameIdxMap& GetOrtValueNameIdxMap() noexcept { return ort_value_name_idx_map_; } // initialized tensors /** @@ -77,6 +85,12 @@ class SessionState { */ Status AddInitializedTensor(int ort_value_index, const OrtValue& ort_value, const OrtCallback* d, bool constant); + Status SetGraph(const Graph& graph); + Status CreateKernels(const KernelRegistryManager& custom_registry_manager); + Status SetGraphAndCreateKernels(const Graph& graph, const KernelRegistryManager& custom_registry_manager) { + ORT_RETURN_IF_ERROR(SetGraph(graph)); + return CreateKernels(custom_registry_manager); + } /** * Gets the map of ort_value_index to initialized tensors (weights) so that it can be used by the * execution frame to setup the appropriate OrtValue vectors. @@ -85,8 +99,8 @@ class SessionState { const std::unordered_map& GetInitializedTensors() const; /** - * Gets the map of ort_value_index to initialized tensors (e.g. weights) that are constant - * and cannot be overridden at runtime. + * Gets the map of ort_value_index to initialized tensors (e.g. weights) that are constant + * and cannot be overridden at runtime. * The lifetime of returned OrtValues are limited by this SessionState object. */ const std::unordered_map& GetConstantInitializedTensors() const; @@ -96,12 +110,12 @@ class SessionState { const SequentialExecutionPlan* GetExecutionPlan() const; /** - Set the logger to use for this session. + Set the logger to use for this session. */ SessionState& SetLogger(const logging::Logger& logger); /** - Get the logger for this session. + Get the logger for this session. Falls back to returning Logging::LoggingManager::DefaultLogger if SetLogger has not been called. */ const logging::Logger& Logger() const; @@ -120,10 +134,11 @@ class SessionState { /** Get cached memory pattern based on input shapes */ - const MemoryPatternGroup* GetMemoryPatternGroup(const std::vector>& input_shapes) const; + const MemoryPatternGroup* GetMemoryPatternGroup( + const std::vector>& input_shapes) const; /** - Set generated memory pattern with a given input shapes. + Set generated memory pattern with a given input shapes. 
Const as it's an internal cache update only. */ Status UpdateMemoryPatternGroupCache(const std::vector>& input_shape, @@ -141,17 +156,15 @@ class SessionState { * \param p_node0 Nullable * \param kci0 Nullable */ - NodeInfo(size_t index0, const onnxruntime::Node* p_node0, const KernelCreateInfo* kci0) - : index(index0), - p_node(p_node0), - kci(kci0) { - } + NodeInfo(size_t index0, const onnxruntime::Node* p_node0, const KernelCreateInfo* kci0, const OrtDevice& device0) + : index(index0), p_node(p_node0), kci(kci0), device(&device0) {} size_t index; // Nullable const onnxruntime::Node* p_node = nullptr; // Nullable const KernelCreateInfo* kci = nullptr; + const OrtDevice* device = nullptr; }; using NameNodeInfoMapType = std::unordered_map>; @@ -174,8 +187,7 @@ class SessionState { SessionState* GetMutableSubgraphSessionState(onnxruntime::NodeIndex index, const std::string& attribute_name); - onnxruntime::concurrency::ThreadPool* GetThreadPool() const { return thread_pool_; } - void SetThreadPool(onnxruntime::concurrency::ThreadPool* p_pool) { thread_pool_ = p_pool; } + concurrency::ThreadPool* GetThreadPool() const { return thread_pool_; } bool ExportDll() const { return export_fused_dll_; } void SetExportDllFlag(bool flag) { export_fused_dll_ = flag; } @@ -187,7 +199,6 @@ class SessionState { void SetDataTransferMgr(const DataTransferManager* data_transfer_mgr) { data_transfer_mgr_ = data_transfer_mgr; } std::vector& GetMutableWeightsBuffers() { return weights_buffers_; } - void CalculateNodeIndexInfo(); const NodeIndexInfo& GetNodeIndexInfo() const; private: @@ -195,7 +206,7 @@ class SessionState { // cache of the constructed kernels to avoid spending construction // time per executor - std::unordered_map> session_kernels_; + std::vector session_kernels_; std::unique_ptr graph_viewer_; const ExecutionProviders& execution_providers_; // owned by InferenceSession @@ -231,7 +242,8 @@ class SessionState { std::unordered_map>>; SubgraphSessionStateMap subgraph_session_states_; - onnxruntime::concurrency::ThreadPool* thread_pool_ = nullptr; + // It could be NULL + concurrency::ThreadPool* const thread_pool_; bool export_fused_dll_ = false; FuncManager fused_funcs_mgr_; diff --git a/onnxruntime/core/framework/session_state_initializer.cc b/onnxruntime/core/framework/session_state_initializer.cc index 3f4777d8608d0..18589de82679a 100644 --- a/onnxruntime/core/framework/session_state_initializer.cc +++ b/onnxruntime/core/framework/session_state_initializer.cc @@ -27,9 +27,6 @@ namespace onnxruntime { -static common::Status SaveMLValueNameIndexMapping(const GraphViewer& graph_viewer, - OrtValueNameIdxMap& ort_value_name_idx_map, - const logging::Logger& logger); // T should have signature of '(int idx, const OrtValue& value, const OrtCallback& d) -> Status' template @@ -40,11 +37,6 @@ static common::Status SaveInitializedTensors(const Env& env, const std::basic_st const logging::Logger& logger, const DataTransferManager& data_transfer_mgr); -static common::Status SaveKernels(const ExecutionProviders& execution_providers, - SessionState& session_state, - const KernelRegistryManager& custom_registry_manager, - const logging::Logger& logger); - static common::Status SaveInputOutputNamesToNodeMapping( const onnxruntime::Graph& graph, const KernelRegistryManager& custom_registry_manager, @@ -68,11 +60,11 @@ common::Status SessionStateInitializer::CreatePlan( const Node* parent_node, const ConstPointerContainer>* outer_scope_node_args, bool enable_sequential_execution) { - auto graph_viewer = 
std::make_unique(graph_); + session_state_.SetGraph(graph_); + const GraphViewer* graph_viewer = session_state_.GetGraphViewer(); // populate the SessionState OrtValueNameIdxMap - auto& ort_value_name_idx_map = session_state_.GetOrtValueNameIdxMap(); - ORT_RETURN_IF_ERROR(SaveMLValueNameIndexMapping(*graph_viewer, ort_value_name_idx_map, logger_)); + const auto& ort_value_name_idx_map = session_state_.GetOrtValueNameIdxMap(); // ignore any outer scope args we don't know about. this can happen if a node contains multiple subgraphs. std::vector valid_outer_scope_node_args; @@ -92,17 +84,10 @@ common::Status SessionStateInitializer::CreatePlan( execution_providers_, kernel_registry_manager_, ort_value_name_idx_map, context, exec_plan)); session_state_.SetExecutionPlan(std::move(exec_plan)); - session_state_.SetGraphViewer(std::move(graph_viewer)); - return Status::OK(); -} - -common::Status SessionStateInitializer::InitializeAndSave( - const ConstPointerContainer>* implicit_inputs) { const auto* exec_plan_ptr = session_state_.GetExecutionPlan(); ORT_ENFORCE(exec_plan_ptr, "Execution plan was not found in SessionState. CreatePlan must be called first."); - const auto& ort_value_name_idx_map{session_state_.GetOrtValueNameIdxMap()}; std::unique_ptr tensor_allocator_(ITensorAllocator::Create( enable_mem_pattern_, *exec_plan_ptr, execution_providers_, session_state_.GetMutableWeightsBuffers())); @@ -119,64 +104,12 @@ common::Status SessionStateInitializer::InitializeAndSave( // TODO: make it better graph_.CleanAllInitializedTensors(); - ORT_RETURN_IF_ERROR(SaveKernels(execution_providers_, session_state_, kernel_registry_manager_, logger_)); - ORT_RETURN_IF_ERROR(SaveInputOutputNamesToNodeMapping(graph_, kernel_registry_manager_, session_state_, - implicit_inputs)); - + ORT_RETURN_IF_ERROR(session_state_.CreateKernels(kernel_registry_manager_)); + ORT_RETURN_IF_ERROR( + SaveInputOutputNamesToNodeMapping(graph_, kernel_registry_manager_, session_state_, outer_scope_node_args)); return Status::OK(); } -// Build the OrtValue name->idx mapping -common::Status SaveMLValueNameIndexMapping(const GraphViewer& graph_viewer, OrtValueNameIdxMap& ort_value_name_idx_map, - const logging::Logger& logger) { - LOGS(logger, INFO) << "SaveMLValueNameIndexMapping"; - int idx = 0; - - // we keep all graph inputs (including initializers), even if they are unused, so make sure they all have an entry - for (const auto* input_def : graph_viewer.GetInputsIncludingInitializers()) { - idx = ort_value_name_idx_map.Add(input_def->Name()); - VLOGS(logger, 1) << "Added graph_viewer input with name: " << input_def->Name() - << " to OrtValueIndex with index: " << idx; - } - - for (auto& node : graph_viewer.Nodes()) { - // build the OrtValue->index map - for (const auto* input_def : node.InputDefs()) { - if (input_def->Exists()) { - idx = ort_value_name_idx_map.Add(input_def->Name()); - VLOGS(logger, 1) << "Added input argument with name: " << input_def->Name() - << " to OrtValueIndex with index: " << idx; - } - } - - for (const auto* input_def : node.ImplicitInputDefs()) { - if (input_def->Exists()) { - idx = ort_value_name_idx_map.Add(input_def->Name()); - VLOGS(logger, 1) << "Added implicit input argument with name: " << input_def->Name() - << " to OrtValueIndex with index: " << idx; - } - } - - for (const auto* output_def : node.OutputDefs()) { - if (output_def->Exists()) { - ort_value_name_idx_map.Add(output_def->Name()); - VLOGS(logger, 1) << "Added output argument with name: " << output_def->Name() - << " to OrtValueIndex 
with index: " << idx; - } - } - } - - // allocate OrtValue for graph outputs when coming from initializers - for (const auto& output : graph_viewer.GetOutputs()) { - if (output->Exists()) { - idx = ort_value_name_idx_map.Add(output->Name()); - VLOGS(logger, 1) << "Added graph output with name: " << output->Name() << " to OrtValueIndex with index: " << idx; - } - } - - LOGS(logger, INFO) << "Done saving OrtValue mappings."; - return Status::OK(); -} static common::Status DeserializeTensorProto(const Env& env, const std::basic_string& proto_path, const ONNX_NAMESPACE::TensorProto& tensor_proto, const MemBuffer& m, @@ -292,46 +225,6 @@ common::Status SaveInitializedTensors(const Env& env, const std::basic_string& op_kernel) { - onnxruntime::ProviderType exec_provider_name = node.GetExecutionProviderType(); - - const IExecutionProvider* exec_provider = nullptr; - if (exec_provider_name.empty() || (exec_provider = execution_providers.Get(exec_provider_name)) == nullptr) { - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Could not create kernel for node: ", node.Name(), - " as there's no execution provider allocated."); - } - - common::Status status = custom_registry_manager.CreateKernel(node, *exec_provider, session_state, op_kernel); - if (!status.IsOK()) { - return common::Status( - status.Category(), status.Code(), - MakeString("Kernel creation failed for node: ", node.Name(), " with error: ", status.ErrorMessage())); - } - - return status; -} - -common::Status SaveKernels(const ExecutionProviders& execution_providers, - SessionState& session_state, - const KernelRegistryManager& custom_registry_manager, - const logging::Logger& logger) { - LOGS(logger, INFO) << "Saving kernels."; - - for (auto& node : session_state.GetGraphViewer()->Nodes()) { - // construct and save the kernels - std::unique_ptr op_kernel; - ORT_RETURN_IF_ERROR(CreateOpKernel(node, execution_providers, session_state, custom_registry_manager, op_kernel)); - session_state.AddKernel(node.Index(), std::move(op_kernel)); - } - - LOGS(logger, INFO) << "Done saving kernels."; - - return Status::OK(); -} - template // T is container of const NodeArg* or NodeArg* static bool IsArgNameInInputsOutputs(const std::string& name, const T& graph_args) { @@ -351,6 +244,8 @@ common::Status SaveInputOutputNamesToNodeMapping(const onnxruntime::Graph& graph if (implicit_inputs && implicit_inputs->empty()) { implicit_inputs = nullptr; } + const auto* exec_plan = session_state.GetExecutionPlan(); + const auto& name_to_id = session_state.GetOrtValueNameIdxMap(); for (auto& node : graph.Nodes()) { // note that KernelCreateInfo may not exist for custom kernel @@ -365,7 +260,11 @@ common::Status SaveInputOutputNamesToNodeMapping(const onnxruntime::Graph& graph return Status::OK(); } - SessionState::NodeInfo node_info(index, &node, kci); + int arg_index; + ORT_RETURN_IF_ERROR(name_to_id.GetIdx(arg.Name(), arg_index)); + const auto& device = exec_plan->GetLocation(arg_index).device; + + SessionState::NodeInfo node_info(index, &node, kci, device); if (IsArgNameInInputsOutputs(arg.Name(), graph_inputs)) { ORT_RETURN_IF_ERROR(session_state.AddInputNameToNodeInfoMapping(arg.Name(), node_info)); @@ -397,8 +296,13 @@ common::Status SaveInputOutputNamesToNodeMapping(const onnxruntime::Graph& graph // copy to/from CPU to go through the control flow nodes where possible/applicable. 
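A note for reviewers on the device plumbing in this hunk: each NodeInfo now records the OrtDevice that the allocation plan assigned to the argument, resolved through the OrtValueNameIdxMap. Below is a minimal, self-contained C++ sketch of that lookup chain; the types are simplified stand-ins carrying the same names as the ORT classes, not the real declarations.

#include <cstddef>
#include <string>
#include <unordered_map>
#include <vector>

// Simplified stand-ins for OrtDevice, SequentialExecutionPlan and OrtValueNameIdxMap;
// the real classes live in core/framework, these exist only to make the sketch compile.
struct Device { int type = 0; int id = 0; };
struct ExecutionPlan { std::vector<Device> locations; };   // indexed by ort_value index
struct NameIdxMap {
  std::unordered_map<std::string, int> map;
  int GetIdx(const std::string& name) const { return map.at(name); }
};

struct NodeInfo {            // mirrors SessionState::NodeInfo after this change
  size_t index;
  const Device* device;      // device the allocation plan assigned to this argument
};

// name -> ort_value index -> allocation-plan location -> device, the same chain
// SaveInputOutputNamesToNodeMapping now follows for inputs, implicit inputs and
// unused graph inputs.
NodeInfo MakeNodeInfo(size_t node_index, const std::string& input_name,
                      const NameIdxMap& name_to_id, const ExecutionPlan& plan) {
  const int arg_index = name_to_id.GetIdx(input_name);
  return NodeInfo{node_index, &plan.locations.at(arg_index)};
}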
// the processing for the subgraph where the implicit input is consumed will do the real check on whether any // copy to a different device is required - SessionState::NodeInfo node_info(std::numeric_limits::max(), &node, kci); for (const auto& input_def : node_implicit_inputs) { + int arg_index; + //Question: the implicit input may not be found in this session state name to id map, but in parent session state name to id map. + //@Scott + ORT_RETURN_IF_ERROR(name_to_id.GetIdx(input_def->Name(), arg_index)); + auto& device = exec_plan->GetLocation(arg_index).device; + SessionState::NodeInfo node_info(std::numeric_limits::max(), &node, kci, device); ORT_RETURN_IF_ERROR(session_state.AddInputNameToNodeInfoMapping(input_def->Name(), node_info)); } } @@ -413,7 +317,6 @@ common::Status SaveInputOutputNamesToNodeMapping(const onnxruntime::Graph& graph auto& input_map = session_state.GetInputNodeInfoMap(); auto end_map = input_map.cend(); - SessionState::NodeInfo empty_node_info(std::numeric_limits::max(), nullptr, nullptr); for (const auto& graph_input : graph_inputs) { const auto& name = graph_input->Name(); @@ -422,6 +325,10 @@ common::Status SaveInputOutputNamesToNodeMapping(const onnxruntime::Graph& graph // utils::CopyOneInputAcrossDevices will use the input OrtValue as is given we don't believe it's used anywhere. LOGS(session_state.Logger(), INFO) << (graph.IsSubgraph() ? "Subgraph" : "Graph") << " input with name " << name << " is not used by any node."; + int arg_index; + ORT_RETURN_IF_ERROR(name_to_id.GetIdx(name, arg_index)); + auto& device = exec_plan->GetLocation(arg_index).device; + SessionState::NodeInfo empty_node_info(std::numeric_limits::max(), nullptr, nullptr, device); ORT_RETURN_IF_ERROR(session_state.AddInputNameToNodeInfoMapping(name, empty_node_info)); } } diff --git a/onnxruntime/core/framework/session_state_initializer.h b/onnxruntime/core/framework/session_state_initializer.h index 3634704de5e2a..8c969571c558a 100644 --- a/onnxruntime/core/framework/session_state_initializer.h +++ b/onnxruntime/core/framework/session_state_initializer.h @@ -36,14 +36,11 @@ class SessionStateInitializer { KernelRegistryManager& kernel_registry_manager); // First perform any transformations and create the execution plan - common::Status CreatePlan(const Node* parent_node, - const ConstPointerContainer>* outer_scope_node_args, + // Then initialize tensors, and save. save kernels and input/output node mappings + common::Status CreatePlan(_In_opt_ const Node* parent_node, + _In_opt_ const ConstPointerContainer>* outer_scope_node_args, bool enable_sequential_execution); - // initialize tensors, and save. 
save kernels and input/output node mappings - // \param implicit_inputs could be NULL - common::Status InitializeAndSave(const ConstPointerContainer>* implicit_inputs); - private: const std::basic_string& graph_loc_; onnxruntime::Graph& graph_; diff --git a/onnxruntime/core/framework/tensor.cc b/onnxruntime/core/framework/tensor.cc index d0085c0fe6c1a..692232a6a8abc 100644 --- a/onnxruntime/core/framework/tensor.cc +++ b/onnxruntime/core/framework/tensor.cc @@ -47,7 +47,7 @@ void Tensor::Init(MLDataType p_type, const TensorShape& shape, void* p_raw_data, byte_offset_ = offset; } -Tensor::Tensor(Tensor&& other) +Tensor::Tensor(Tensor&& other) noexcept : p_data_(other.p_data_), buffer_deleter_(other.buffer_deleter_), shape_(other.shape_), @@ -61,7 +61,7 @@ Tensor::Tensor(Tensor&& other) other.byte_offset_ = 0; } -Tensor& Tensor::operator=(Tensor&& other) { +Tensor& Tensor::operator=(Tensor&& other) noexcept { if (this != &other) { ReleaseBuffer(); diff --git a/onnxruntime/core/framework/tensor_shape.cc b/onnxruntime/core/framework/tensor_shape.cc index b37c2e9499c8a..72acfd7921975 100644 --- a/onnxruntime/core/framework/tensor_shape.cc +++ b/onnxruntime/core/framework/tensor_shape.cc @@ -8,16 +8,8 @@ namespace onnxruntime { -TensorShape::TensorShape(const std::vector& dims) : std::vector(dims) { -} - -TensorShape::TensorShape(std::vector&& dims) : std::vector(std::move(dims)) { -} - -TensorShape::TensorShape(const std::initializer_list& dims) : std::vector(dims) { -} - -TensorShape::TensorShape(const int64_t* dimension_sizes, size_t dimension_count) : std::vector(dimension_count) { +TensorShape::TensorShape(const int64_t* dimension_sizes, size_t dimension_count) + : std::vector(dimension_count) { for (size_t i = 0; i < dimension_count; ++i) { (*this)[i] = dimension_sizes[i]; } diff --git a/onnxruntime/core/framework/tensorprotoutils.cc b/onnxruntime/core/framework/tensorprotoutils.cc index ce7fc4e91d286..768045179b5be 100644 --- a/onnxruntime/core/framework/tensorprotoutils.cc +++ b/onnxruntime/core/framework/tensorprotoutils.cc @@ -14,7 +14,7 @@ #include "core/framework/tensor.h" #include "core/framework/ort_value_pattern_planner.h" #include "core/framework/allocator.h" -#include "core/common/callback.h" +#include "core/framework/callback.h" #include "core/framework/data_types.h" #include "core/framework/path_lib.h" @@ -304,7 +304,7 @@ ORT_API_STATUS(OrtInitializeBufferForTensor, _In_opt_ void* input, size_t input_ */ ORT_API(void, OrtUninitializeBuffer, _In_opt_ void* input, size_t input_len, enum ONNXTensorElementDataType type); -static void ORT_API_CALL UnInitTensor(void* param) noexcept { +static void UnInitTensor(void* param) noexcept { UnInitializeParam* p = reinterpret_cast(param); OrtUninitializeBuffer(p->preallocated, p->preallocated_size, p->ele_type); delete p; diff --git a/onnxruntime/core/framework/utils.cc b/onnxruntime/core/framework/utils.cc index b0171f25be843..cb126236d15e9 100644 --- a/onnxruntime/core/framework/utils.cc +++ b/onnxruntime/core/framework/utils.cc @@ -16,16 +16,51 @@ #include "core/framework/parallel_executor.h" #include "core/framework/session_state.h" #include "core/framework/sequential_executor.h" +#include "core/mlas/inc/mlas.h" namespace onnxruntime { namespace utils { +void* DefaultAlloc(size_t size) { + if (size <= 0) return nullptr; + void* p; + size_t alignment = MlasGetPreferredBufferAlignment(); +#if _MSC_VER + p = _aligned_malloc(size, alignment); + if (p == nullptr) throw std::bad_alloc(); +#elif defined(_LIBCPP_SGX_CONFIG) + p = 
memalign(alignment, size); + if (p == nullptr) throw std::bad_alloc(); +#else + int ret = posix_memalign(&p, alignment, size); + if (ret != 0) throw std::bad_alloc(); +#endif + return p; +} + +void DefaultFree(void* p) { +#if _MSC_VER + _aligned_free(p); +#else + free(p); +#endif +} + AllocatorPtr GetAllocator(const SessionState& session_state, const OrtAllocatorInfo& allocator_info) { return session_state.GetExecutionProviders().GetAllocator(allocator_info); } -common::Status AllocateHelper(const IExecutionProvider& execution_provider, int device_id, const Tensor& fetched_tensor, +bool ProviderIsCpuBased(const std::string& provider_type) { + return provider_type == onnxruntime::kCpuExecutionProvider || + provider_type == onnxruntime::kMklDnnExecutionProvider || + provider_type == onnxruntime::kNGraphExecutionProvider || + provider_type == onnxruntime::kNupharExecutionProvider || + provider_type == onnxruntime::kOpenVINOExecutionProvider || + provider_type == onnxruntime::kNnapiExecutionProvider; +} + +common::Status AllocateHelper(const IExecutionProvider& execution_provider, const OrtDevice& device, const Tensor& fetched_tensor, OrtValue& output_mlvalue) { - auto allocator = execution_provider.GetAllocator(device_id, OrtMemTypeDefault); + auto allocator = execution_provider.GetAllocator(device.Id(), OrtMemTypeDefault); if (!allocator) { return Status(common::ONNXRUNTIME, common::FAIL, "invalid allocator"); } @@ -62,20 +97,20 @@ static Status CopyMLValue(const DataTransferManager& data_transfer_mgr, const FeedsFetchesManager::MLValueCopyInfo& copy_info, const OrtValue& source_mlvalue, OrtValue& target_mlvalue) { - if (copy_info.copy_provider == nullptr) { + if (copy_info.allocation_provider == nullptr) { target_mlvalue = source_mlvalue; - } else { - auto& source_tensor = source_mlvalue.Get(); + return Status::OK(); + } - if (!target_mlvalue.IsAllocated()) { - ORT_RETURN_IF_ERROR(utils::AllocateHelper(*copy_info.allocation_provider, copy_info.allocation_device_id, - source_tensor, target_mlvalue)); - } + auto& source_tensor = source_mlvalue.Get(); + if (!target_mlvalue.IsAllocated()) { + ORT_RETURN_IF_ERROR(utils::AllocateHelper(*copy_info.allocation_provider, copy_info.target_device, + source_tensor, target_mlvalue)); + } - Tensor* p_output_tensor = target_mlvalue.GetMutable(); + Tensor* p_output_tensor = target_mlvalue.GetMutable(); - ORT_RETURN_IF_ERROR(data_transfer_mgr.CopyTensor(source_tensor, *p_output_tensor)); - } + ORT_RETURN_IF_ERROR(data_transfer_mgr.CopyTensor(source_tensor, *p_output_tensor)); return Status::OK(); } @@ -86,8 +121,6 @@ common::Status CopyOneInputAcrossDevices(const SessionState& session_state, cons FeedsFetchesManager::MLValueCopyInfo& copy_info) { needed_copy = false; - //TODO: make it configurable - const int target_device_id = 0; std::vector node_info_vec; ORT_RETURN_IF_ERROR(session_state.GetInputNodeInfo(input_name, node_info_vec)); @@ -111,51 +144,23 @@ common::Status CopyOneInputAcrossDevices(const SessionState& session_state, cons break; } - auto& required_provider_type = GetNodeInputProviderType(node_info); - auto& input_tensor = orig_mlvalue.Get(); - auto& input_tensor_loc = input_tensor.Location(); - - auto* p_input_provider = exec_providers.Get(input_tensor_loc); - if (!p_input_provider) { - p_input_provider = exec_providers.Get(onnxruntime::kCpuExecutionProvider); - ORT_ENFORCE(p_input_provider); - } - - //no copy for TRT and nGraph - if (required_provider_type == onnxruntime::kTensorrtExecutionProvider || required_provider_type == 
onnxruntime::kNGraphExecutionProvider) { - new_mlvalue = orig_mlvalue; - break; - } - - auto input_provider_type = p_input_provider->Type(); - if (input_provider_type == required_provider_type && input_tensor_loc.mem_type == OrtMemTypeDefault) { - new_mlvalue = orig_mlvalue; - break; - } - - // If a node requires input on cpu and input tensor is allocated with pinned memory allocator, don't do copy - if (required_provider_type == onnxruntime::kCpuExecutionProvider && - input_tensor_loc.mem_type == OrtMemTypeCPU) { + auto& required_device = *node_info.device; + auto& input_tensor_device = orig_mlvalue.Get().Location().device; + if (required_device == input_tensor_device) { + // No copy needed for same device. new_mlvalue = orig_mlvalue; break; } + auto& required_provider_type = GetNodeInputProviderType(node_info); auto* required_provider = exec_providers.Get(required_provider_type); - ORT_ENFORCE(required_provider); - - auto* p_copy_provider = (required_provider_type != onnxruntime::kCpuExecutionProvider) - ? required_provider - : p_input_provider; - - copy_info.allocation_device_id = target_device_id; + copy_info.target_device = required_device; copy_info.allocation_provider = required_provider; - copy_info.copy_provider = p_copy_provider; ORT_RETURN_IF_ERROR(CopyMLValue(session_state.GetDataTransferMgr(), copy_info, orig_mlvalue, new_mlvalue)); needed_copy = true; - // } loop of node_info_vec } while (false); return Status::OK(); @@ -223,18 +228,16 @@ static common::Status CachedCopyInputsAcrossDevices( // Setup fetches for execution. Use any provided fetches directly if the provider matches. // If the provider doesn't match, we don't know what device the execution output may be on, so can't assume the output // can be returned to the user directly. -// TODO: We should be able to use the allocation plan to know which device an output will be on. static common::Status SetupFetchesForExecute(const SessionState& session_state, const std::vector& output_names, std::vector& fetches, std::vector& new_fetches, std::vector* copy_to_new_fetches_cached_values) { ORT_ENFORCE(new_fetches.empty()); - - const auto& execution_providers = session_state.GetExecutionProviders(); auto num_outputs = output_names.size(); - new_fetches.resize(num_outputs); + const auto& name_to_id = session_state.GetOrtValueNameIdxMap(); + const auto* exec_plan = session_state.GetExecutionPlan(); // track which fetches can be copied to new_fetches and used directly in the execution. 
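The rewritten CopyOneInputAcrossDevices above reduces the copy decision to a single OrtDevice comparison instead of matching provider types and memory types. A small compilable sketch of that predicate follows; Device and InputCopyInfo are stand-ins for OrtDevice and the relevant MLValueCopyInfo fields, used here only to keep the sketch self-contained.

#include <string>

struct Device {                 // stand-in for OrtDevice, which compares by value in the hunk above
  int type = 0;                 // e.g. CPU vs GPU
  int id = 0;
  bool operator==(const Device& other) const { return type == other.type && id == other.id; }
};

struct InputCopyInfo {          // the two fields the hunk now fills in on MLValueCopyInfo
  Device target_device;
  std::string allocation_provider_type;
};

// Mirrors the new CopyOneInputAcrossDevices decision: copy only when the device the
// consuming node was planned on (NodeInfo::device) differs from where the tensor lives.
bool PlanInputCopy(const Device& required_device, const std::string& required_provider_type,
                   const Device& input_tensor_device, InputCopyInfo& copy_info) {
  if (required_device == input_tensor_device) return false;  // same device: reuse the OrtValue
  copy_info.target_device = required_device;
  copy_info.allocation_provider_type = required_provider_type;
  return true;                                               // caller performs the actual CopyTensor
}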
std::vector local_can_copy_flags(num_outputs, false); @@ -275,16 +278,12 @@ static common::Status SetupFetchesForExecute(const SessionState& session_state, continue; } - const auto& node_provider_type = node.GetExecutionProviderType(); - const auto& provided_tensor = provided_mlvalue.Get(); - const auto& provided_tensor_loc = provided_tensor.Location(); - const auto* tensor_provider = execution_providers.Get(provided_tensor_loc); - if (!tensor_provider) { - tensor_provider = execution_providers.Get(onnxruntime::kCpuExecutionProvider); - } + int arg_index; + ORT_RETURN_IF_ERROR(name_to_id.GetIdx(arg->Name(), arg_index)); + const auto& planned_device = exec_plan->GetLocation(arg_index).device; + const auto& provided_tensor_device = provided_mlvalue.Get().Location().device; - auto tensor_provider_type = tensor_provider->Type(); - if (node_provider_type == tensor_provider_type) { + if (planned_device == provided_tensor_device) { new_fetches[idx] = fetches[idx]; local_can_copy_flags[idx] = true; continue; @@ -344,43 +343,26 @@ static common::Status CopyOutputsAcrossDevices(const SessionState& session_state continue; } - auto& fetched_tensor = fetched_mlvalue.Get(); - auto& fetched_tensor_location = fetched_tensor.Location(); - auto* p_fetched_provider = execution_providers.Get(fetched_tensor_location); - if (!p_fetched_provider) { - p_fetched_provider = cpu_execution_provider; - } - - auto fetched_provider_type = p_fetched_provider->Type(); - auto& output_mlvalue = user_fetches[idx]; - const IExecutionProvider* p_output_provider = nullptr; - + auto target_device = OrtDevice(); + auto& output_mlvalue = user_fetches[idx]; if (output_mlvalue.IsAllocated()) { Tensor* p_output_tensor = output_mlvalue.GetMutable(); + target_device = p_output_tensor->Location().device; p_output_provider = execution_providers.Get(p_output_tensor->Location()); } + auto fetch_result_device = fetched_mlvalue.Get().Location().device; + if (target_device == fetch_result_device) { + user_fetches[idx] = fetched_mlvalue; + continue; + } if (!p_output_provider) { p_output_provider = cpu_execution_provider; } - auto output_provider_type = p_output_provider->Type(); - - if (fetched_provider_type == output_provider_type || - (p_output_provider == cpu_execution_provider && fetched_tensor_location.mem_type == OrtMemTypeCPUOutput)) { - user_fetches[idx] = fetched_mlvalue; - continue; - } - needed_copy = true; - - auto* p_copy_provider = (fetched_provider_type != onnxruntime::kCpuExecutionProvider) - ? p_fetched_provider - : p_output_provider; - - const int device_id = 0; // TODO: As per comment in the copy input code, make this configurable. 
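The output path gets the same treatment: MLValueCopyInfo now carries a target OrtDevice plus the allocation provider, and the separate copy provider and hard-coded device id 0 are gone. Below is a sketch of the decision CopyOutputsAcrossDevices makes, under the assumption taken from the hunk that an unallocated user output falls back to a default-constructed OrtDevice; the types are stand-ins, not the real ORT declarations.

#include <optional>

struct Device {                                // stand-in for OrtDevice; the default-constructed
  int type = 0;                                // value plays the role OrtDevice() plays in the hunk
  int id = 0;
  bool operator==(const Device& other) const { return type == other.type && id == other.id; }
};

struct CopyInfo {                              // shape of FeedsFetchesManager::MLValueCopyInfo after this change
  Device target_device;                        // a device, no longer a bare device id
  const void* allocation_provider = nullptr;   // opaque stand-in for IExecutionProvider*
};

// Returns the copy to perform, or nothing when the fetched value can be returned as-is.
// user_output_device is empty when the caller did not pre-allocate the output.
std::optional<CopyInfo> PlanOutputCopy(const std::optional<Device>& user_output_device,
                                       const Device& fetch_result_device,
                                       const void* output_provider) {
  const Device target = user_output_device.value_or(Device{});
  if (target == fetch_result_device) return std::nullopt;     // devices match: hand the fetch back
  return CopyInfo{target, output_provider};
}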
- FeedsFetchesManager::MLValueCopyInfo copy_info{device_id, p_output_provider, p_copy_provider}; + FeedsFetchesManager::MLValueCopyInfo copy_info{target_device, p_output_provider}; ORT_RETURN_IF_ERROR(CopyMLValue(session_state.GetDataTransferMgr(), copy_info, fetched_mlvalue, output_mlvalue)); if (copiers) { @@ -410,11 +392,7 @@ static common::Status CachedCopyOutputsAcrossDevices( static DeviceCopyCheck CheckExecutionProviders(const ExecutionProviders& execution_providers) { for (const auto& execution_provider : execution_providers) { - if (execution_provider->Type() != onnxruntime::kCpuExecutionProvider && - execution_provider->Type() != onnxruntime::kMklDnnExecutionProvider && - execution_provider->Type() != onnxruntime::kNGraphExecutionProvider && - execution_provider->Type() != onnxruntime::kNupharExecutionProvider && - execution_provider->Type() != onnxruntime::kOpenVINOExecutionProvider) { + if (!ProviderIsCpuBased(execution_provider->Type())) { return DeviceCopyCheck::Unknown; } } diff --git a/onnxruntime/core/framework/utils.h b/onnxruntime/core/framework/utils.h index b096f1ecbaf8b..881762da85a2d 100644 --- a/onnxruntime/core/framework/utils.h +++ b/onnxruntime/core/framework/utils.h @@ -25,6 +25,8 @@ class Logger; } namespace utils { +void* DefaultAlloc(size_t size); +void DefaultFree(void* p); AllocatorPtr GetAllocator(const SessionState& session_state, const OrtAllocatorInfo& allocator_info); diff --git a/onnxruntime/core/graph/automl_ops/automl_defs.cc b/onnxruntime/core/graph/automl_ops/automl_defs.cc new file mode 100644 index 0000000000000..dc4dd653f37c0 --- /dev/null +++ b/onnxruntime/core/graph/automl_ops/automl_defs.cc @@ -0,0 +1,46 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/graph/constants.h" +#include "core/graph/automl_ops/automl_defs.h" +#include "core/graph/op.h" +#include "onnx/defs/schema.h" +#include "onnx/defs/shape_inference.h" + +namespace onnxruntime { +namespace automl { +using ONNX_NAMESPACE::AttributeProto; +using ONNX_NAMESPACE::OpSchema; +using ONNX_NAMESPACE::OPTIONAL; + +void RegisterAutoMLSchemas() { + + static const char* DateTimeTransformer_ver1_doc = R"DOC( + DateTimeTransformer accepts a single scalar int64 tensor, constructs + an instance of std::chrono::system_clock::time_point and passes it as an argument + to Microsoft::DateTimeFeaturizer which is a part of a shared library. + It returns an instance of TimePoint class. + )DOC"; + + MS_AUTOML_OPERATOR_SCHEMA(DateTimeTransformer) + .SinceVersion(1) + .SetDomain(kMSAutoMLDomain) + .SetDoc(DateTimeTransformer_ver1_doc) + .Input(0, "X", + "The input represents a number of seconds passed since the epoch, suitable to properly construct" + "an instance of std::chrono::system_clock::time_point", + "T1") + .Output(0, "Y", "The output which is a Microsoft::DateTimeFeaturizer::TimePoint structure", "T2") + .TypeConstraint( + "T1", + {"tensor(int64)"}, + "Constrain input type to int64 scalar tensor.") + .TypeConstraint( + "T2", + {"opaque(com.microsoft.automl,DateTimeFeaturizer_TimePoint)"}, + "Constrain output type to an AutoML specific Microsoft::Featurizers::TimePoint type" + "currently not part of ONNX standard. 
When it becomes a part of the standard we will adjust this" + "kernel definition and move it to ONNX repo"); +} +} // namespace automl +} // namespace onnxruntime diff --git a/onnxruntime/core/graph/automl_ops/automl_defs.h b/onnxruntime/core/graph/automl_ops/automl_defs.h new file mode 100644 index 0000000000000..b1a37366c396d --- /dev/null +++ b/onnxruntime/core/graph/automl_ops/automl_defs.h @@ -0,0 +1,30 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/graph/onnx_protobuf.h" + +namespace onnxruntime { +namespace automl { +#define MS_AUTOML_OPERATOR_SCHEMA(name) \ + MS_AUTOML_OPERATOR_SCHEMA_UNIQ_HELPER(__COUNTER__, name) +#define MS_AUTOML_OPERATOR_SCHEMA_UNIQ_HELPER(Counter, name) \ + MS_AUTOML_OPERATOR_SCHEMA_UNIQ(Counter, name) +#define MS_AUTOML_OPERATOR_SCHEMA_UNIQ(Counter, name) \ + static ONNX_NAMESPACE::OpSchemaRegistry::OpSchemaRegisterOnce( \ + op_schema_register_once##name##Counter) ONNX_UNUSED = \ + ONNX_NAMESPACE::OpSchema(#name, __FILE__, __LINE__) + +#define MS_AUTOML_OPERATOR_SCHEMA_ELSEWHERE(name, schema_func) \ + MS_AUTOML_OPERATOR_SCHEMA_UNIQ_HELPER_ELSEWHERE(__COUNTER__, name, schema_func) +#define MS_AUTOML_OPERATOR_SCHEMA_UNIQ_HELPER_ELSEWHERE(Counter, name, schema_func) \ + MS_AUTOML_OPERATOR_SCHEMA_UNIQ_ELSEWHERE(Counter, name, schema_func) +#define MS_AUTOML_OPERATOR_SCHEMA_UNIQ_ELSEWHERE(Counter, name, schema_func) \ + static ONNX_NAMESPACE::OpSchemaRegistry::OpSchemaRegisterOnce( \ + op_schema_register_once##name##Counter) ONNX_UNUSED = \ + schema_func(ONNX_NAMESPACE::OpSchema(#name, __FILE__, __LINE__)) + +void RegisterAutoMLSchemas(); +} // namespace automl +} // namespace onnxruntime diff --git a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc index 49d8e3309a989..66d85474461cd 100644 --- a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc @@ -21,6 +21,10 @@ void convPoolShapeInference( int input1Idx, int input2Idx); void globalPoolTypeShapeInference(ONNX_NAMESPACE::InferenceContext& ctx); +void matmulShapeInference( + ONNX_NAMESPACE::InferenceContext& ctx, + int input1Idx, + int input2Idx); } // namespace ONNX_NAMESPACE namespace onnxruntime { @@ -1158,6 +1162,39 @@ of [N, 0] then [N, 0]. updateOutputShape(ctx, 0, output_shape); }); + ONNX_CONTRIB_OPERATOR_SCHEMA(MatMulInteger16) + .SetDomain(kMSDomain) + .SinceVersion(1) + .SetDoc(R"DOC( +Matrix product that behaves like numpy.matmul: https://docs.scipy.org/doc/numpy-1.13.0/reference/generated/numpy.matmul.html. + The production MUST never overflow. The accumulation may overflow if and only if in 32 bits.)DOC") + .Input(0, "A", "N-dimensional matrix A", "T1") + .Input(1, "B", "N-dimensional matrix B", "T2") + .Output(0, "Y", "Matrix multiply results from A * B", "T3") + .TypeConstraint("T1", {"tensor(int16)", "tensor(uint16)"}, "Constrain input A data types as 16-bit integer tensor") + .TypeConstraint("T2", {"tensor(int16)", "tensor(uint16)"}, "Constrain input B data types as 16-bit integer tensor") + .TypeConstraint("T3", + {"tensor(int32)", "tensor(uint32)"}, + "Constrain output Y data types as 32-bit integer tensor." 
+ "T3 must be tensor(uint32) when both T1 and T2 are tensor(uint16)," + "or must be tensor(int32) when either T1 or T2 is tensor(int16).") + .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) { + auto a_type = ctx.getInputType(0); + auto b_type = ctx.getInputType(1); + auto y_type = ctx.getOutputType(0); + if (nullptr == a_type || nullptr == b_type || nullptr == y_type || + a_type->value_case() != ONNX_NAMESPACE::TypeProto::kTensorType || + b_type->value_case() != ONNX_NAMESPACE::TypeProto::kTensorType) { + fail_type_inference( + "inputs are expected to have tensor type and output type should not be null."); + } + + // Right now we only support int32 + y_type->mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TensorProto::INT32); + + matmulShapeInference(ctx, 0, 1); + }); + ONNX_CONTRIB_OPERATOR_SCHEMA(ReduceSumInteger) .SetDomain(kMSDomain) .SinceVersion(1) @@ -1599,4 +1636,4 @@ Example 4: #endif } } // namespace contrib -} // namespace onnxruntime +} // namespace onnxruntime \ No newline at end of file diff --git a/onnxruntime/core/graph/graph_viewer.cc b/onnxruntime/core/graph/graph_viewer.cc index 262a2591ddb08..53ee4047b4994 100644 --- a/onnxruntime/core/graph/graph_viewer.cc +++ b/onnxruntime/core/graph/graph_viewer.cc @@ -8,6 +8,8 @@ #include "core/graph/graph_viewer.h" +#include "core/graph/graph_utils.h" + namespace onnxruntime { struct NodeCompare { @@ -25,12 +27,13 @@ GraphViewer::GraphViewer(const Graph& graph) { leaf_nodes.push_back(&node); } } - graph.ReverseDFSFrom(leaf_nodes, - nullptr, - [this](const Node* n) { - nodes_in_topological_order_.push_back(n->Index()); - }, - NodeCompare()); + graph.ReverseDFSFrom( + leaf_nodes, + nullptr, + [this](const Node* n) { + nodes_in_topological_order_.push_back(n->Index()); + }, + NodeCompare()); for (auto& node : graph_->Nodes()) { if (node.InputEdgesBegin() == node.InputEdgesEnd()) { @@ -52,6 +55,10 @@ bool GraphViewer::GetInitializedTensor(const std::string& tensor_name, const ONN return graph_->GetInitializedTensor(tensor_name, value); } +bool GraphViewer::CanOverrideInitializer() const noexcept { + return graph_->CanOverrideInitializer(); +} + // Graph inputs excluding initializers. const std::vector& GraphViewer::GetInputs() const noexcept { return graph_->GetInputs(); @@ -109,4 +116,8 @@ bool GraphViewer::IsSubgraph() const { return graph_->IsSubgraph(); } +bool GraphViewer::IsConstantInitializer(const std::string& name, bool check_outer_scope) const { + return graph_utils::IsConstantInitializer(*graph_, name, check_outer_scope); +} + } // namespace onnxruntime diff --git a/onnxruntime/core/graph/model.cc b/onnxruntime/core/graph/model.cc index 3d03b2b4efd5d..d2da1752292ba 100644 --- a/onnxruntime/core/graph/model.cc +++ b/onnxruntime/core/graph/model.cc @@ -108,18 +108,25 @@ Model::Model(std::unique_ptr model_proto, const IOnnxRuntimeOpSchema const auto& domain = opSet.domain(); const auto version = opSet.version(); // empty domain and 'ai.onnx' are equivalent - if ((domain.empty() || domain == "ai.onnx") && version < 7) { + if ((domain.empty() || domain == kOnnxDomainAlias) && version < 7) { // TODO: Check if we can upgrade all the current opset 6 models that are being tested // in CI to opset 7 or above LOGS_DEFAULT(WARNING) << "ONNX Runtime only *guarantees* support for models stamped " "with opset version 7 or above for opset domain 'ai.onnx'. " "Please upgrade your model to opset 7 or higher. 
" "For now, this opset " - << version + << version << " model may run depending upon legacy support " "of some older opset version operators."; } - domain_to_version[domain] = gsl::narrow_cast(version); + // We need to overwrite the domain here with ("") or else the loop below will try to find ("") + // in the map and if not found (when domain == kOnnxDomainAlias), adds an entry for ("", 11). + // This effectively ignores the opset version specified by the model for the onnx domain. + if (domain == kOnnxDomainAlias) { + domain_to_version[kOnnxDomain] = gsl::narrow_cast(version); + } else { + domain_to_version[domain] = gsl::narrow_cast(version); + } } auto domain_map = schema_registry->GetLatestOpsetVersions(false); diff --git a/onnxruntime/core/mlas/inc/mlas.h b/onnxruntime/core/mlas/inc/mlas.h index b1e08f09b567c..884d97042b4bc 100644 --- a/onnxruntime/core/mlas/inc/mlas.h +++ b/onnxruntime/core/mlas/inc/mlas.h @@ -129,6 +129,27 @@ MlasSgemm( MLAS_THREADPOOL* ThreadPool ); +// +// Quantized integer matrix/matrix multiply routine. +// + +void +MLASCALL +MlasQgemm( + size_t M, + size_t N, + size_t K, + const uint8_t* A, + size_t lda, + uint8_t offa, + const uint8_t* B, + size_t ldb, + uint8_t offb, + int32_t* C, + size_t ldc, + MLAS_THREADPOOL* ThreadPool + ); + // // Convolution routines. // diff --git a/onnxruntime/core/mlas/lib/aarch64/sgemma.s b/onnxruntime/core/mlas/lib/aarch64/SgemmKernelNeon.S similarity index 95% rename from onnxruntime/core/mlas/lib/aarch64/sgemma.s rename to onnxruntime/core/mlas/lib/aarch64/SgemmKernelNeon.S index 545465a5a86e8..c69fadc36893b 100644 --- a/onnxruntime/core/mlas/lib/aarch64/sgemma.s +++ b/onnxruntime/core/mlas/lib/aarch64/SgemmKernelNeon.S @@ -6,7 +6,7 @@ Licensed under the MIT License. Module Name: - sgemma.s + SgemmKernelNeon.s Abstract: @@ -88,7 +88,7 @@ Abstract: .endm - +// // MultiplyAccumulateRow // // Generates the code to multiply and accumulate a single row of the output @@ -137,11 +137,11 @@ Abstract: ClearBlockAccumulators \Columns\(),\Rows\() .if \Rows\() >= 2 - add x10,x0,x6,uxtw 2 // compute matrix A plus 1 row + add x10,x0,x6,lsl #2 // compute matrix A plus 1 row .endif .if \Rows\() >= 4 - add x11,x10,x6,uxtw 2 // compute matrix A plus 2 rows - add x12,x11,x6,uxtw 2 // compute matrix A plus 3 rows + add x11,x10,x6,lsl #2 // compute matrix A plus 2 rows + add x12,x11,x6,lsl #2 // compute matrix A plus 3 rows .endif sub x9,x3,#4 // decrement block count to process @@ -183,7 +183,7 @@ Abstract: ldp q6,q7,[x1,#-8*4] .endif MultiplyAccumulateBlock \Columns\(),\Rows\(),0 - sub x9,x9,1 + sub x9,x9,#1 cbnz x9,.L\Mode\().Compute\Columns\().x\Rows\().BlockBy1Loop .L\Mode\().Output\Columns\().x\Rows\().Block: @@ -430,12 +430,12 @@ Return Value: .type MlasSgemmKernel\Mode\(),%function MlasSgemmKernel\Mode\(): - stp d8,d9,[sp,-32]! - stp d10,d11,[sp,16] + stp d8,d9,[sp,#-32]! 
+ stp d10,d11,[sp,#16] - add x13,x2,x7,uxtw 2 // compute matrix C plus 1 row - add x14,x13,x7,uxtw 2 // compute matrix C plus 2 rows - add x15,x14,x7,uxtw 2 // compute matrix C plus 3 rows + add x13,x2,x7,lsl #2 // compute matrix C plus 1 row + add x14,x13,x7,lsl #2 // compute matrix C plus 2 rows + add x15,x14,x7,lsl #2 // compute matrix C plus 3 rows mov x8,x0 // save matrix A // @@ -452,8 +452,8 @@ MlasSgemmKernel\Mode\(): .L\Mode\().ExitKernel: mov x0,x4 - ldp d10,d11,[sp,16] - ldp d8,d9,[sp],32 + ldp d10,d11,[sp,#16] + ldp d8,d9,[sp],#32 ret // diff --git a/onnxruntime/core/mlas/lib/amd64/AssembleAvx512Vnni.inc b/onnxruntime/core/mlas/lib/amd64/AssembleAvx512Vnni.inc new file mode 100644 index 0000000000000..02f7d92256017 --- /dev/null +++ b/onnxruntime/core/mlas/lib/amd64/AssembleAvx512Vnni.inc @@ -0,0 +1,232 @@ +;++ +; +; Copyright (c) Microsoft Corporation. All rights reserved. +; +; Licensed under the MIT License. +; +; Module Name: +; +; AssembleAvx512Vnni.inc +; +; Abstract: +; +; This module contains macros to build VNNI instructions for toolchains that +; do not natively support this newer instruction set extension. +; +;-- + +; +; Map friendly register names to the encoded register index. +; + +ZmmIndex_zmm0 EQU 0 +ZmmIndex_zmm1 EQU 1 +ZmmIndex_zmm2 EQU 2 +ZmmIndex_zmm3 EQU 3 +ZmmIndex_zmm4 EQU 4 +ZmmIndex_zmm5 EQU 5 +ZmmIndex_zmm6 EQU 6 +ZmmIndex_zmm7 EQU 7 +ZmmIndex_zmm8 EQU 8 +ZmmIndex_zmm9 EQU 9 +ZmmIndex_zmm10 EQU 10 +ZmmIndex_zmm11 EQU 11 +ZmmIndex_zmm12 EQU 12 +ZmmIndex_zmm13 EQU 13 +ZmmIndex_zmm14 EQU 14 +ZmmIndex_zmm15 EQU 15 +ZmmIndex_zmm16 EQU 16 +ZmmIndex_zmm17 EQU 17 +ZmmIndex_zmm18 EQU 18 +ZmmIndex_zmm19 EQU 19 +ZmmIndex_zmm20 EQU 20 +ZmmIndex_zmm21 EQU 21 +ZmmIndex_zmm22 EQU 22 +ZmmIndex_zmm23 EQU 23 +ZmmIndex_zmm24 EQU 24 +ZmmIndex_zmm25 EQU 25 +ZmmIndex_zmm26 EQU 26 +ZmmIndex_zmm27 EQU 27 +ZmmIndex_zmm28 EQU 28 +ZmmIndex_zmm29 EQU 29 +ZmmIndex_zmm30 EQU 30 +ZmmIndex_zmm31 EQU 31 + +GprIndex_rax EQU 0 +GprIndex_rcx EQU 1 +GprIndex_rdx EQU 2 +GprIndex_rbx EQU 3 +GprIndex_rbp EQU 5 +GprIndex_rsi EQU 6 +GprIndex_rdi EQU 7 +GprIndex_r8 EQU 8 +GprIndex_r9 EQU 9 +GprIndex_r10 EQU 10 +GprIndex_r11 EQU 11 +GprIndex_r12 EQU 12 +GprIndex_r13 EQU 13 +GprIndex_r14 EQU 14 +GprIndex_r15 EQU 15 + +; +; Macro Description: +; +; This macro builds a VNNI instruction of the form: +; +; instr zmm1,zmm2,zmm3 +; +; Arguments: +; +; Opcode - Specifies the opcode for the VNNI instruction. +; +; DestReg - Specifies the destination register. +; +; Src1Reg - Specifies the first source register. +; +; Src2Reg - Specifies the second source register. 
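Before the encoding macros below, a quick semantic reference: these macros only hand-assemble the AVX512 VNNI encodings for toolchains without native support, they do not change what the instructions compute. As I understand the VNNI semantics, a per-lane scalar model looks like the following (plain C++, independent of MLAS; the saturating Vpdpbusds/Vpdpwssds variants differ only in saturating the accumulate).

#include <cstdint>

// One 32-bit lane of VPDPBUSD: dot product of four unsigned bytes from the first
// source with four signed bytes from the second, added to the dword accumulator.
int32_t VpdpbusdLane(int32_t acc, const uint8_t a[4], const int8_t b[4]) {
  for (int i = 0; i < 4; ++i) acc += int32_t(a[i]) * int32_t(b[i]);
  return acc;
}

// One 32-bit lane of VPDPWSSD: dot product of two signed 16-bit pairs added to the
// dword accumulator, i.e. the vpmaddwd + vpaddd sequence used by the AVX2 kernels
// in this change collapsed into a single instruction.
int32_t VpdpwssdLane(int32_t acc, const int16_t a[2], const int16_t b[2]) {
  return acc + int32_t(a[0]) * int32_t(b[0]) + int32_t(a[1]) * int32_t(b[1]);
}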
+; + +VnniZmmZmmZmm MACRO Opcode, DestReg, Src1Reg, Src2Reg + + LOCAL Payload0, Payload1, Payload2, ModRMByte + + Payload0 = 002h ; "0F 38" prefix + Payload0 = Payload0 + ((((ZmmIndex_&DestReg& SHR 3) AND 1) XOR 1) SHL 7) + Payload0 = Payload0 + ((((ZmmIndex_&Src2Reg& SHR 4) AND 1) XOR 1) SHL 6) + Payload0 = Payload0 + ((((ZmmIndex_&Src2Reg& SHR 3) AND 1) XOR 1) SHL 5) + Payload0 = Payload0 + ((((ZmmIndex_&DestReg& SHR 4) AND 1) XOR 1) SHL 4) + + Payload1 = 005h ; "66" prefix + Payload1 = Payload1 + (((ZmmIndex_&Src1Reg& AND 15) XOR 15) SHL 3) + + Payload2 = 040h ; 512-bit vector length + Payload2 = Payload2 + ((((ZmmIndex_&Src1Reg& SHR 4) AND 1) XOR 1) SHL 3) + + ModRMByte = 0C0h ; register form + ModRMByte = ModRMByte + ((ZmmIndex_&DestReg& AND 7) SHL 3) + ModRMByte = ModRMByte + (ZmmIndex_&Src2Reg& AND 7) + + db 062h, Payload0, Payload1, Payload2, Opcode, ModRMByte + + ENDM + +VpdpbusdZmmZmmZmm MACRO DestReg, Src1Reg, Src2Reg + + VnniZmmZmmZmm 050h, DestReg, Src1Reg, Src2Reg + + ENDM + +VpdpbusdsZmmZmmZmm MACRO DestReg, Src1Reg, Src2Reg + + VnniZmmZmmZmm 051h, DestReg, Src1Reg, Src2Reg + + ENDM + +VpdpwssdZmmZmmZmm MACRO DestReg, Src1Reg, Src2Reg + + VnniZmmZmmZmm 052h, DestReg, Src1Reg, Src2Reg + + ENDM + +VpdpwssdsZmmZmmZmm MACRO DestReg, Src1Reg, Src2Reg + + VnniZmmZmmZmm 053h, DestReg, Src1Reg, Src2Reg + + ENDM + +; +; Macro Description: +; +; This macro builds a VNNI instruction of the form: +; +; instr zmm1,zmm2,DWORD BCST [BaseReg+IndexReg*Scale] +; +; Arguments: +; +; Opcode - Specifies the opcode for the VNNI instruction. +; +; DestReg - Specifies the destination register. +; +; Src1Reg - Specifies the first source register. +; +; BaseReg - Specifies the base register of the broadcast operand. +; +; IndexReg - Specifies the optional index register of the broadcast operand. +; +; Scale - Specifies the scaling factor of the optional index register. 
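For readability, here is the arithmetic of the register-form VnniZmmZmmZmm macro above transcribed into C++. The shifts and masks are copied from the macro as written; I have not re-derived them against the EVEX specification, so treat this as documentation of the macro rather than of the encoding itself.

#include <array>
#include <cstdint>

// Byte-for-byte transcription of VnniZmmZmmZmm (register form): emits
// 62h, Payload0, Payload1, Payload2, Opcode, ModRM for "instr zmmDest, zmmSrc1, zmmSrc2".
std::array<uint8_t, 6> EncodeVnniZmmZmmZmm(uint8_t opcode, unsigned dest, unsigned src1, unsigned src2) {
  uint8_t payload0 = 0x02;                                  // "0F 38" prefix
  payload0 |= (((dest >> 3) & 1) ^ 1) << 7;
  payload0 |= (((src2 >> 4) & 1) ^ 1) << 6;
  payload0 |= (((src2 >> 3) & 1) ^ 1) << 5;
  payload0 |= (((dest >> 4) & 1) ^ 1) << 4;

  uint8_t payload1 = 0x05;                                  // "66" prefix
  payload1 |= ((src1 & 15) ^ 15) << 3;

  uint8_t payload2 = 0x40;                                  // 512-bit vector length
  payload2 |= (((src1 >> 4) & 1) ^ 1) << 3;

  uint8_t modrm = 0xC0;                                     // register form
  modrm |= (dest & 7) << 3;
  modrm |= src2 & 7;

  return {0x62, payload0, payload1, payload2, opcode, modrm};
}

// e.g. EncodeVnniZmmZmmZmm(0x50, 4, 5, 6) corresponds to VpdpbusdZmmZmmZmm zmm4,zmm5,zmm6.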
+; + +VnniZmmZmmBroadcast MACRO Opcode, DestReg, Src1Reg, BaseReg, IndexReg, Scale + + LOCAL Payload0, Payload1, Payload2, ModRMByte, SibByte + + Payload0 = 002h ; "0F 38" prefix + Payload0 = Payload0 + ((((ZmmIndex_&DestReg& SHR 3) AND 1) XOR 1) SHL 7) +IFNB + Payload0 = Payload0 + ((((GprIndex_&IndexReg& SHR 3) AND 1) XOR 1) SHL 6) +ELSE + Payload0 = Payload0 + 040h ; zero logical index register +ENDIF + Payload0 = Payload0 + ((((GprIndex_&BaseReg& SHR 3) AND 1) XOR 1) SHL 5) + Payload0 = Payload0 + ((((ZmmIndex_&DestReg& SHR 4) AND 1) XOR 1) SHL 4) + + Payload1 = 005h ; "66" prefix + Payload1 = Payload1 + (((ZmmIndex_&Src1Reg& AND 15) XOR 15) SHL 3) + + Payload2 = 050h ; 512-bit vector length, broadcast + Payload2 = Payload2 + ((((ZmmIndex_&Src1Reg& SHR 4) AND 1) XOR 1) SHL 3) + + ModRMByte = 000h ; memory form + ModRMByte = ModRMByte + ((ZmmIndex_&DestReg& AND 7) SHL 3) +IFNB + ModRMByte = ModRMByte + 004h ; indicate SIB byte needed +ELSE + ModRMByte = ModRMByte + (GprIndex_&BaseReg& AND 7) +ENDIF + +IFNB + SibByte = 0 +IF Scale EQ 2 + SibByte = SibByte + (1 SHL 6) +ELSEIF Scale EQ 4 + SibByte = SibByte + (2 SHL 6) +ELSEIF Scale EQ 8 + SibByte = SibByte + (3 SHL 6) +ELSEIF Scale NE 1 + .err +ENDIF + SibByte = SibByte + ((GprIndex_&IndexReg& AND 7) SHL 3) + SibByte = SibByte + (GprIndex_&BaseReg& AND 7) +ENDIF + +IFNB + db 062h, Payload0, Payload1, Payload2, Opcode, ModRMByte, SibByte +ELSE + db 062h, Payload0, Payload1, Payload2, Opcode, ModRMByte +ENDIF + + ENDM + +VpdpbusdZmmZmmBroadcast MACRO DestReg, Src1Reg, BaseReg, IndexReg, Scale + + VnniZmmZmmBroadcast 050h, DestReg, Src1Reg, BaseReg, IndexReg, Scale + + ENDM + +VpdpbusdsZmmZmmBroadcast MACRO DestReg, Src1Reg, BaseReg, IndexReg, Scale + + VnniZmmZmmBroadcast 051h, DestReg, Src1Reg, BaseReg, IndexReg, Scale + + ENDM + +VpdpwssdZmmZmmBroadcast MACRO DestReg, Src1Reg, BaseReg, IndexReg, Scale + + VnniZmmZmmBroadcast 052h, DestReg, Src1Reg, BaseReg, IndexReg, Scale + + ENDM + +VpdpwssdsZmmZmmBroadcast MACRO DestReg, Src1Reg, BaseReg, IndexReg, Scale + + VnniZmmZmmBroadcast 053h, DestReg, Src1Reg, BaseReg, IndexReg, Scale + + ENDM diff --git a/onnxruntime/core/mlas/lib/amd64/QgemmU8U8KernelAvx2.asm b/onnxruntime/core/mlas/lib/amd64/QgemmU8U8KernelAvx2.asm new file mode 100644 index 0000000000000..365348a14db1f --- /dev/null +++ b/onnxruntime/core/mlas/lib/amd64/QgemmU8U8KernelAvx2.asm @@ -0,0 +1,1241 @@ +;++ +; +; Copyright (c) Microsoft Corporation. All rights reserved. +; +; Licensed under the MIT License. +; +; Module Name: +; +; QgemmU8U8KernelAvx2.asm +; +; Abstract: +; +; This module implements the kernels for the quantized integer matrix/matrix +; multiply operation (QGEMM). +; +; This implementation uses AVX2 instructions. +; +;-- + + .xlist +INCLUDE mlasi.inc + .list + + EXTERN MlasMaskMoveAvx:NEAR + +; +; Stack frame layout for the U8U8 CopyPackA routine. +; + +GemmU8U8CopyPackAFrame STRUCT + + PaddedMatrixAData OWORD 4 DUP (?) + SavedXmm6 OWORD ? + SavedXmm7 OWORD ? + SavedXmm8 OWORD ? + SavedXmm9 OWORD ? + Padding QWORD ? + SavedR13 QWORD ? + SavedR12 QWORD ? + SavedRdi QWORD ? + SavedRsi QWORD ? + SavedRbx QWORD ? + SavedRbp QWORD ? + ReturnAddress QWORD ? + PreviousP1Home QWORD ? + PreviousP2Home QWORD ? + PreviousP3Home QWORD ? + PreviousP4Home QWORD ? + CountK QWORD ? + RowSumVector QWORD ? + offb QWORD ? + +GemmU8U8CopyPackAFrame ENDS + +; +; Stack frame layout for the U8U8 CopyPackB routine. +; + +GemmU8U8CopyPackBFrame STRUCT + + PaddedMatrixBData OWORD 2 DUP (?) + SavedRsi QWORD ? + SavedRbx QWORD ? 
+ SavedRbp QWORD ? + ReturnAddress QWORD ? + PreviousP1Home QWORD ? + PreviousP2Home QWORD ? + PreviousP3Home QWORD ? + PreviousP4Home QWORD ? + CountK QWORD ? + ColumnSumVector QWORD ? + offa QWORD ? + +GemmU8U8CopyPackBFrame ENDS + +; +; Stack frame layout for the U8U8 kernel. +; + +GemmU8U8KernelFrame STRUCT + + SavedXmm6 OWORD ? + SavedXmm7 OWORD ? + SavedXmm8 OWORD ? + SavedXmm9 OWORD ? + SavedXmm10 OWORD ? + SavedXmm11 OWORD ? + SavedXmm12 OWORD ? + SavedXmm13 OWORD ? + SavedXmm14 OWORD ? + SavedXmm15 OWORD ? + SavedR14 QWORD ? + SavedR13 QWORD ? + SavedR12 QWORD ? + SavedRdi QWORD ? + SavedRsi QWORD ? + SavedRbx QWORD ? + SavedRbp QWORD ? + ReturnAddress QWORD ? + PreviousP1Home QWORD ? + PreviousP2Home QWORD ? + PreviousP3Home QWORD ? + PreviousP4Home QWORD ? + CountM QWORD ? + CountN QWORD ? + ldc QWORD ? + RowSumVector QWORD ? + ColumnSumVector QWORD ? + DepthValue QWORD ? + ZeroMode QWORD ? + +GemmU8U8KernelFrame ENDS + +;++ +; +; Routine Description: +; +; This routine copies elements from the source matrix to the destination +; packed buffer. +; +; The kernel expects that elements from matrix A have been zero extended to +; 16-bits and padded to a multiple of 32-bits (two pairs of 16-bit values). +; The kernel can then efficiently broadcast 32-bits from the packed buffer +; and avoid expensive shuffling inside the kernel. +; +; Arguments: +; +; D (rcx) - Supplies the address of the destination packed buffer. +; +; A (rdx) - Supplies the address of the source matrix. +; +; lda (r8) - Supplies the number of elements per row of the source matrix. +; +; CountM (r9) - Supplies the number of rows of the source matrix to copy. +; +; CountK - Supplies the number of columns of the source matrix to copy. +; +; RowSumVector - Supplies the address of the buffer to receive the sums of +; the elements from each of the rows. Each sum has also been multiplied +; by the zero point offset. +; +; offb - Supplies the zero point offset for the other source matrix of the +; matrix multiplication. +; +; Return Value: +; +; None. +; +;-- + + NESTED_ENTRY MlasGemmU8U8CopyPackAAvx2, _TEXT + + rex_push_reg rbp + push_reg rbx + push_reg rsi + push_reg rdi + push_reg r12 + push_reg r13 + alloc_stack (GemmU8U8CopyPackAFrame.SavedR13) + save_xmm128_avx xmm6,GemmU8U8CopyPackAFrame.SavedXmm6 + save_xmm128_avx xmm7,GemmU8U8CopyPackAFrame.SavedXmm7 + save_xmm128_avx xmm8,GemmU8U8CopyPackAFrame.SavedXmm8 + save_xmm128_avx xmm9,GemmU8U8CopyPackAFrame.SavedXmm9 + + END_PROLOGUE + + mov rdi,rcx + mov rsi,rdx + mov r10,GemmU8U8CopyPackAFrame.CountK[rsp] + lea r11,[r10+1] + and r11,NOT 1 ; align CountK up to pair count + mov r12,GemmU8U8CopyPackAFrame.RowSumVector[rsp] + vpbroadcastw xmm8,WORD PTR GemmU8U8CopyPackAFrame.offb[rsp] + +; +; Compute the conditional load/store mask for an unaligned CountK. +; + + mov eax,r10d + and eax,15 ; isolate unaligned count + inc eax + shr eax,1 ; align unaligned count to pair count + mov DWORD PTR GemmU8U8CopyPackAFrame.CountK[rsp],eax + vpbroadcastd ymm9,DWORD PTR GemmU8U8CopyPackAFrame.CountK[rsp] + vpcmpgtd ymm9,ymm9,YMMWORD PTR [MlasMaskMoveAvx] + +; +; Zero initialize the padded stack buffers. +; + + vpxor xmm0,xmm0,xmm0 + vmovdqu YMMWORD PTR GemmU8U8CopyPackAFrame.PaddedMatrixAData[rsp],ymm0 + vmovdqu YMMWORD PTR GemmU8U8CopyPackAFrame.PaddedMatrixAData[rsp+32],ymm0 + +; +; Process 4 rows of matrix A in a loop. +; +; For each row, zero extend the source bytes to 16-bits and write to the packed +; buffer. 
The packed buffer has the same data ordering as the source bytes, but +; the stride is CountK aligned up to an even number of 16-bit values. +; +; These 16-bit values are also accumulated into an intermediate per-row +; accumulator. CountK cannot be greater than 256 to avoid overflowing these +; 16-bit accumulators. +; + + sub r9,4 + jb ProcessRemainingRows + +ProcessNextRowM4: + vpxor xmm0,xmm0,xmm0 ; clear row accumulators + vpxor xmm1,xmm1,xmm1 + vpxor xmm2,xmm2,xmm2 + vpxor xmm3,xmm3,xmm3 + mov rdx,rsi + mov rcx,rdi + lea rsi,[rsi+r8*4] ; advance next matrix A by 4 rows + lea rdi,[rdi+r11*(2*4)] ; advance next matrix D by 4 rows + mov rbx,r10 ; reload columns remaining + sub rbx,16 + jb ProcessRemainingColumnsM4 + +ProcessNextColumnLoopM4: + lea rax,[rdx+r8*2] ; compute matrix A plus two rows + vpmovzxbw ymm4,XMMWORD PTR [rdx] + vpmovzxbw ymm5,XMMWORD PTR [rdx+r8] + vpmovzxbw ymm6,XMMWORD PTR [rax] + vpmovzxbw ymm7,XMMWORD PTR [rax+r8] + lea rax,[rcx+r11*4] ; compute matrix D plus two rows + vmovdqu YMMWORD PTR [rcx],ymm4 + vmovdqu YMMWORD PTR [rcx+r11*2],ymm5 + vmovdqu YMMWORD PTR [rax],ymm6 + vmovdqu YMMWORD PTR [rax+r11*2],ymm7 + vpaddw ymm0,ymm0,ymm4 ; accumulate per row along columns + vpaddw ymm1,ymm1,ymm5 + vpaddw ymm2,ymm2,ymm6 + vpaddw ymm3,ymm3,ymm7 + add rdx,16 ; advance matrix A by 16 bytes + add rcx,16*2 ; advance matrix D by 16 words + sub rbx,16 ; subtract columns remaining + jae ProcessNextColumnLoopM4 + +ProcessRemainingColumnsM4: + add rbx,16 ; correct for over-subtract above + jz ReduceRowSumVectorM4 + +; +; Copy the unaligned CountK columns to a zero padded stack buffer. +; + +.errnz GemmU8U8CopyPackAFrame.PaddedMatrixAData + mov rbp,rsp ; GemmU8U8CopyPackAFrame.PaddedMatrixAData + test bl,8 ; (CountK & 8) != 0? + jz CopyRemainingCountKLessThan8M4 + lea r13,[rdx+r8*2] ; compute matrix A plus two rows + mov rax,QWORD PTR [rdx] + mov QWORD PTR [rbp],rax + mov rax,QWORD PTR [rdx+r8] + mov QWORD PTR [rbp+16],rax + mov rax,QWORD PTR [r13] + mov QWORD PTR [rbp+32],rax + mov rax,QWORD PTR [r13+r8] + mov QWORD PTR [rbp+48],rax + add rdx,8 + add rbp,8 ; advance padded buffer destination + +CopyRemainingCountKLessThan8M4: + test bl,4 ; (CountK & 4) != 0? + jz CopyRemainingCountKLessThan4M4 + lea r13,[rdx+r8*2] ; compute matrix A plus two rows + mov eax,DWORD PTR [rdx] + mov DWORD PTR [rbp],eax + mov eax,DWORD PTR [rdx+r8] + mov DWORD PTR [rbp+16],eax + mov eax,DWORD PTR [r13] + mov DWORD PTR [rbp+32],eax + mov eax,DWORD PTR [r13+r8] + mov DWORD PTR [rbp+48],eax + add rdx,4 + add rbp,4 ; advance padded buffer destination + +CopyRemainingCountKLessThan4M4: + test bl,2 ; (CountK & 2) != 0? + jz CopyRemainingCountKLessThan2M4 + lea r13,[rdx+r8*2] ; compute matrix A plus two rows + movzx eax,WORD PTR [rdx] + mov WORD PTR [rbp],ax + movzx eax,WORD PTR [rdx+r8] + mov WORD PTR [rbp+16],ax + movzx eax,WORD PTR [r13] + mov WORD PTR [rbp+32],ax + movzx eax,WORD PTR [r13+r8] + mov WORD PTR [rbp+48],ax + add rdx,2 + add rbp,2 ; advance padded buffer destination + +CopyRemainingCountKLessThan2M4: + test bl,1 ; (CountK & 1) != 0? + jz ProcessPaddedMatrixADataM4 + lea r13,[rdx+r8*2] ; compute matrix A plus two rows + movzx eax,BYTE PTR [rdx] + mov BYTE PTR [rbp],al + movzx eax,BYTE PTR [rdx+r8] + mov BYTE PTR [rbp+16],al + movzx eax,BYTE PTR [r13] + mov BYTE PTR [rbp+32],al + movzx eax,BYTE PTR [r13+r8] + mov BYTE PTR [rbp+48],al + +; +; Process the remaining CountK columns using the zero padded stack buffer. 
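For reviewers more comfortable with C than MASM, this is my reading of the packed-A layout and row sums described above, written as a scalar reference. ReferenceCopyPackA is a hypothetical helper produced for this note; it models the data contract only, not the AVX2 tiling or the stack-buffer handling of unaligned tails.

#include <cstdint>
#include <cstddef>

// Scalar reference for the U8U8 CopyPackA layout described above: zero-extend each
// byte of A to 16 bits, pad each row to an even number of 16-bit values, and emit
// RowSumVector[m] = offb * sum(row m). CountK must not exceed 256 so the 16-bit row
// accumulators used by the real kernel cannot overflow.
void ReferenceCopyPackA(uint16_t* D, const uint8_t* A, size_t lda,
                        size_t CountM, size_t CountK,
                        int32_t* RowSumVector, uint8_t offb) {
  const size_t AlignedK = (CountK + 1) & ~size_t(1);        // pair-align the packed stride
  for (size_t m = 0; m < CountM; ++m) {
    int32_t row_sum = 0;
    for (size_t k = 0; k < CountK; ++k) {
      const uint16_t v = A[m * lda + k];
      D[m * AlignedK + k] = v;                              // same ordering as the source row
      row_sum += v;
    }
    for (size_t k = CountK; k < AlignedK; ++k) D[m * AlignedK + k] = 0;  // zero padding
    RowSumVector[m] = row_sum * int32_t(offb);              // pre-multiplied by B's zero point
  }
}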
+; + +ProcessPaddedMatrixADataM4: + vpmovzxbw ymm4,XMMWORD PTR GemmU8U8CopyPackAFrame.PaddedMatrixAData[rsp] + vpmovzxbw ymm5,XMMWORD PTR GemmU8U8CopyPackAFrame.PaddedMatrixAData[rsp+16] + vpmovzxbw ymm6,XMMWORD PTR GemmU8U8CopyPackAFrame.PaddedMatrixAData[rsp+32] + vpmovzxbw ymm7,XMMWORD PTR GemmU8U8CopyPackAFrame.PaddedMatrixAData[rsp+48] + lea rax,[rcx+r11*4] ; compute matrix D plus two rows + vpmaskmovd YMMWORD PTR [rcx],ymm9,ymm4 + vpmaskmovd YMMWORD PTR [rcx+r11*2],ymm9,ymm5 + vpmaskmovd YMMWORD PTR [rax],ymm9,ymm6 + vpmaskmovd YMMWORD PTR [rax+r11*2],ymm9,ymm7 + vpaddw ymm0,ymm0,ymm4 ; accumulate per row along columns + vpaddw ymm1,ymm1,ymm5 + vpaddw ymm2,ymm2,ymm6 + vpaddw ymm3,ymm3,ymm7 + +; +; Reduce the sums for the four rows of output. Transpose the intermediate +; accumulators by treating the registers as 32-bit elements containing a pair +; of 16-bit sums. Continue reducing the transposed accumulators to produce the +; final 32-bit vector output. +; + +ReduceRowSumVectorM4: + vpunpckldq ymm4,ymm0,ymm1 ; [A5 B5 A4 B4 A1 B1 A0 B0] + vpunpckhdq ymm5,ymm0,ymm1 ; [A7 B7 A6 B6 A3 B3 A2 B2] + vpunpckldq ymm6,ymm2,ymm3 ; [C5 D5 C4 D4 C1 D1 C0 D0] + vpunpckhdq ymm7,ymm2,ymm3 ; [C7 D7 C6 D6 C3 D3 C2 D2] + vpunpcklqdq ymm0,ymm4,ymm6 ; [A4 B4 C4 D4 A0 B0 C0 D0] + vpunpckhqdq ymm1,ymm4,ymm6 ; [A5 B5 C5 D5 A1 B1 C1 D1] + vpunpcklqdq ymm2,ymm5,ymm7 ; [A6 B6 C6 D6 A2 B2 C2 D2] + vpunpckhqdq ymm3,ymm5,ymm7 ; [A7 B7 C7 D7 A3 B3 C3 D3] + vpaddw ymm0,ymm0,ymm1 ; reduction + vpaddw ymm0,ymm0,ymm2 + vpaddw ymm0,ymm0,ymm3 + vextracti128 xmm1,ymm0,1 ; extract high pairs + vpaddw xmm0,xmm0,xmm1 ; reduction + vpmaddwd xmm0,xmm0,xmm8 ; multiply by offset and reduce + vmovdqu XMMWORD PTR [r12],xmm0 + add r12,4*4 ; advance row sum vector by 4 dwords + sub r9,4 ; subtract rows remaining + jae ProcessNextRowM4 + +ProcessRemainingRows: + add r9,4 ; correct for over-subtract above + jz ExitRoutine + +; +; Process a single row of matrix A in a loop. +; + +ProcessNextRowM1: + vpxor xmm0,xmm0,xmm0 ; clear row accumulator + mov rdx,rsi + mov rcx,rdi + add rsi,r8 + lea rdi,[rdi+r11*2] + mov rbx,r10 ; reload columns remaining + sub rbx,16 + jb ProcessRemainingColumnsM1 + +ProcessNextColumnLoopM1: + vpmovzxbw ymm4,XMMWORD PTR [rdx] + vmovdqu YMMWORD PTR [rcx],ymm4 + vpaddw ymm0,ymm0,ymm4 ; accumulate per row along columns + add rdx,16 ; advance matrix A by 16 bytes + add rcx,16*2 ; advance matrix D by 16 words + sub rbx,16 ; subtract columns remaining + jae ProcessNextColumnLoopM1 + +ProcessRemainingColumnsM1: + add rbx,16 ; correct for over-subtract above + jz ReduceRowSumVectorM1 + +; +; Copy the unaligned CountK columns to a zero padded stack buffer. +; + +.errnz GemmU8U8CopyPackAFrame.PaddedMatrixAData + mov rbp,rsp ; GemmU8U8CopyPackAFrame.PaddedMatrixAData + test bl,8 ; (CountK & 8) != 0? + jz CopyRemainingCountKLessThan8M1 + mov rax,QWORD PTR [rdx] + mov QWORD PTR [rbp],rax + add rdx,8 + add rbp,8 ; advance padded buffer destination + +CopyRemainingCountKLessThan8M1: + test bl,4 ; (CountK & 4) != 0? + jz CopyRemainingCountKLessThan4M1 + mov eax,DWORD PTR [rdx] + mov DWORD PTR [rbp],eax + add rdx,4 + add rbp,4 ; advance padded buffer destination + +CopyRemainingCountKLessThan4M1: + test bl,2 ; (CountK & 2) != 0? + jz CopyRemainingCountKLessThan2M1 + movzx eax,WORD PTR [rdx] + mov WORD PTR [rbp],ax + add rdx,2 + add rbp,2 ; advance padded buffer destination + +CopyRemainingCountKLessThan2M1: + test bl,1 ; (CountK & 1) != 0? 
+ jz ProcessPaddedMatrixADataM1 + movzx eax,BYTE PTR [rdx] + mov BYTE PTR [rbp],al + +; +; Process the remaining CountK columns using the zero padded stack buffer. +; + +ProcessPaddedMatrixADataM1: + vpmovzxbw ymm4,XMMWORD PTR GemmU8U8CopyPackAFrame.PaddedMatrixAData[rsp] + vpmaskmovd YMMWORD PTR [rcx],ymm9,ymm4 + vpaddw ymm0,ymm0,ymm4 ; accumulate per row along columns + +; +; Reduce the sum for the single row of output. +; + +ReduceRowSumVectorM1: + vextracti128 xmm1,ymm0,1 ; extract high pairs + vpaddw xmm0,xmm0,xmm1 ; reduction + vphaddw xmm0,xmm0,xmm0 + vphaddw xmm0,xmm0,xmm0 + vpmaddwd xmm0,xmm0,xmm8 ; multiply by offset and reduce + vmovd DWORD PTR [r12],xmm0 + add r12,4 ; advance row sum vector by 1 DWORD + dec r9 ; decrement rows remaining + jnz ProcessNextRowM1 + +; +; Restore non-volatile registers and return. +; + +ExitRoutine: + vzeroupper + vmovaps xmm6,GemmU8U8CopyPackAFrame.SavedXmm6[rsp] + vmovaps xmm7,GemmU8U8CopyPackAFrame.SavedXmm7[rsp] + vmovaps xmm8,GemmU8U8CopyPackAFrame.SavedXmm8[rsp] + vmovaps xmm9,GemmU8U8CopyPackAFrame.SavedXmm9[rsp] + add rsp,(GemmU8U8CopyPackAFrame.SavedR13) + + BEGIN_EPILOGUE + + pop r13 + pop r12 + pop rdi + pop rsi + pop rbx + pop rbp + ret + + NESTED_END MlasGemmU8U8CopyPackAAvx2, _TEXT + +;++ +; +; Routine Description: +; +; This routine copies elements from the source matrix to the destination +; packed buffer. +; +; Arguments: +; +; D (rcx) - Supplies the address of the destination packed buffer. +; +; B (rdx) - Supplies the address of the source matrix. +; +; ldb (r8) - Supplies the number of elements per row of the source matrix. +; +; CountN (r9) - Supplies the number of columns of the source matrix to copy. +; +; CountK - Supplies the number of rows of the source matrix to copy. +; +; ColumnSumVector - Supplies the address of the buffer to receive the sums of +; the elements from each of the columns. Each sum has also been multiplied +; by the zero point offset. +; +; offa - Supplies the zero point offset for the other source matrix of the +; matrix multiplication. +; +; Return Value: +; +; None. +; +;-- + + NESTED_ENTRY MlasGemmU8U8CopyPackBAvx2, _TEXT + + rex_push_reg rbp + push_reg rbx + push_reg rsi + alloc_stack (GemmU8U8CopyPackBFrame.SavedRsi) + + END_PROLOGUE + + mov rsi,rdx + mov r10,GemmU8U8CopyPackBFrame.CountK[rsp] + mov r11,GemmU8U8CopyPackBFrame.ColumnSumVector[rsp] + vpbroadcastw ymm5,WORD PTR GemmU8U8CopyPackBFrame.offa[rsp] + +; +; Zero initialize the padded stack buffers. +; + + vpxor xmm0,xmm0,xmm0 + vmovdqu YMMWORD PTR GemmU8U8CopyPackBFrame.PaddedMatrixBData[rsp],ymm0 + +; +; Process 16 columns of matrix B in a loop. 
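Same idea for the B-side packing: rows are interleaved in pairs along K (the vpunpcklbw/vpunpckhbw sequence below), and ColumnSumVector receives per-column sums pre-multiplied by offa. The scalar sketch below covers only the column-sum contract; ReferenceCopyPackBColumnSums is a hypothetical helper written for this note, not part of MLAS.

#include <cstdint>
#include <cstddef>

// Scalar reference for the per-column sums produced by the U8U8 CopyPackB routine:
// ColumnSumVector[n] = offa * sum over k of B[k][n]. The packed data itself stores
// each 16-column panel with rows interleaved in pairs (B[k][n], B[k+1][n]) so the
// kernel can read one 16-bit pair per column per step; an odd trailing row is
// paired with zero.
void ReferenceCopyPackBColumnSums(const uint8_t* B, size_t ldb,
                                  size_t CountK, size_t CountN,
                                  int32_t* ColumnSumVector, uint8_t offa) {
  for (size_t n = 0; n < CountN; ++n) {
    int32_t col_sum = 0;
    for (size_t k = 0; k < CountK; ++k) col_sum += B[k * ldb + n];
    ColumnSumVector[n] = col_sum * int32_t(offa);
  }
}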
+; + + sub r9,16 + jb ProcessRemainingColumns + +ProcessNextColumnN16: + vpxor xmm0,xmm0,xmm0 ; clear column accumulators + vpxor xmm1,xmm1,xmm1 + mov rdx,rsi + add rsi,16 ; advance next matrix B by 16 columns + mov rbx,r10 ; reload rows remaining + sub rbx,2 + jb ProcessRemainingRowsN16 + +ProcessNextRowLoopN16: + vmovdqu xmm2,XMMWORD PTR [rdx] ; load two rows + vmovdqu xmm3,XMMWORD PTR [rdx+r8] + lea rdx,[rdx+r8*2] ; advance matrix B by two rows + vpunpcklbw xmm4,xmm2,xmm3 ; interleave row data + vpunpckhbw xmm3,xmm2,xmm3 + vmovdqu XMMWORD PTR [rcx],xmm4 ; store interleaved rows + vmovdqu XMMWORD PTR [rcx+16],xmm3 + vpmovzxbw ymm4,xmm4 + vpmovzxbw ymm3,xmm3 + add rcx,32 ; advance matrix D by 32 bytes + vpaddw ymm0,ymm0,ymm4 ; accumulate per column + vpaddw ymm1,ymm1,ymm3 + sub rbx,2 ; subtract columns remaining + jae ProcessNextRowLoopN16 + +ProcessRemainingRowsN16: + add rbx,2 ; correct for over-subtract above + jz ReduceColumnSumVectorN16 + vpmovzxbw ymm4,XMMWORD PTR [rdx] + vmovdqu YMMWORD PTR [rcx],ymm4 ; store interleaved rows + vextracti128 xmm3,ymm4,1 + vpmovzxbw ymm4,xmm4 + vpmovzxbw ymm3,xmm3 + vpaddw ymm0,ymm0,ymm4 ; accumulate per column + vpaddw ymm1,ymm1,ymm3 + add rcx,32 ; advance matrix D by 32 bytes + +ReduceColumnSumVectorN16: + vpmaddwd ymm0,ymm0,ymm5 ; multiply by offset and reduce + vpmaddwd ymm1,ymm1,ymm5 ; multiply by offset and reduce + vmovdqu YMMWORD PTR [r11],ymm0 + vmovdqu YMMWORD PTR [r11+32],ymm1 + add r11,64 ; advance column sum vector by 16 dwords + sub r9,16 ; subtract columns remaining + jae ProcessNextColumnN16 + +ProcessRemainingColumns: + add r9,16 ; correct for over-subtract above + jnz ProcessColumnNUnaligned + +; +; Restore non-volatile registers and return. +; + +ExitRoutine: + vzeroupper + add rsp,(GemmU8U8CopyPackBFrame.SavedRsi) + + BEGIN_EPILOGUE + + pop rsi + pop rbx + pop rbp + ret + +; +; Process the remaining columns of matrix B. +; + +ProcessColumnNUnaligned: + vpxor xmm0,xmm0,xmm0 ; clear column accumulators + vpxor xmm1,xmm1,xmm1 + sub r10,2 + jb ProcessRemainingRowsNUnaligned + +ProcessNextRowLoopNUnaligned: + mov rdx,rsi +.errnz GemmU8U8CopyPackBFrame.PaddedMatrixBData + mov rbp,rsp ; GemmU8U8CopyPackBFrame.PaddedMatrixBData + test r9b,8 ; (CountN & 8) != 0? + jz CopyRemainingCountNLessThan8K2 + mov rax,QWORD PTR [rdx] + mov QWORD PTR [rbp],rax + mov rax,QWORD PTR [rdx+r8] + mov QWORD PTR [rbp+16],rax + add rdx,8 ; advance matrix B + add rbp,8 ; advance padded buffer destination + +CopyRemainingCountNLessThan8K2: + test r9b,4 ; (CountN & 4) != 0? + jz CopyRemainingCountNLessThan4K2 + mov eax,DWORD PTR [rdx] + mov DWORD PTR [rbp],eax + mov eax,DWORD PTR [rdx+r8] + mov DWORD PTR [rbp+16],eax + add rdx,4 ; advance matrix B + add rbp,4 ; advance padded buffer destination + +CopyRemainingCountNLessThan4K2: + test r9b,2 ; (CountN & 2) != 0? + jz CopyRemainingCountNLessThan2K2 + movzx eax,WORD PTR [rdx] + mov WORD PTR [rbp],ax + movzx eax,WORD PTR [rdx+r8] + mov WORD PTR [rbp+16],ax + add rdx,2 ; advance matrix B + add rbp,2 ; advance padded buffer destination + +CopyRemainingCountNLessThan2K2: + test r9b,1 ; (CountN & 1) != 0? 
+ jz ProcessPaddedMatrixBDataK2
+ movzx eax,BYTE PTR [rdx]
+ mov BYTE PTR [rbp],al
+ movzx eax,BYTE PTR [rdx+r8]
+ mov BYTE PTR [rbp+16],al
+
+ProcessPaddedMatrixBDataK2:
+ vmovdqu xmm2,XMMWORD PTR GemmU8U8CopyPackBFrame.PaddedMatrixBData[rsp]
+ vmovdqu xmm3,XMMWORD PTR GemmU8U8CopyPackBFrame.PaddedMatrixBData[rsp+16]
+ vpunpcklbw xmm4,xmm2,xmm3 ; interleave row data
+ vpunpckhbw xmm3,xmm2,xmm3
+ vmovdqu XMMWORD PTR [rcx],xmm4 ; store interleaved rows
+ vmovdqu XMMWORD PTR [rcx+16],xmm3
+ vpmovzxbw ymm4,xmm4
+ vpmovzxbw ymm3,xmm3
+ vpaddw ymm0,ymm0,ymm4 ; accumulate per column
+ vpaddw ymm1,ymm1,ymm3
+ lea rsi,[rsi+r8*2] ; advance next matrix B by two rows
+ add rcx,32 ; advance matrix D by 32 bytes
+ sub r10,2 ; subtract rows remaining
+ jae ProcessNextRowLoopNUnaligned
+
+ProcessRemainingRowsNUnaligned:
+ add r10,2
+ jz ReduceColumnSumVectorNUnaligned
+ mov rdx,rsi
+.errnz GemmU8U8CopyPackBFrame.PaddedMatrixBData
+ mov rbp,rsp ; GemmU8U8CopyPackBFrame.PaddedMatrixBData
+ test r9b,8 ; (CountN & 8) != 0?
+ jz CopyRemainingCountNLessThan8K1
+ mov rax,QWORD PTR [rdx]
+ mov QWORD PTR [rbp],rax
+ add rdx,8 ; advance matrix B
+ add rbp,8 ; advance padded buffer destination
+
+CopyRemainingCountNLessThan8K1:
+ test r9b,4 ; (CountN & 4) != 0?
+ jz CopyRemainingCountNLessThan4K1
+ mov eax,DWORD PTR [rdx]
+ mov DWORD PTR [rbp],eax
+ add rdx,4 ; advance matrix B
+ add rbp,4 ; advance padded buffer destination
+
+CopyRemainingCountNLessThan4K1:
+ test r9b,2 ; (CountN & 2) != 0?
+ jz CopyRemainingCountNLessThan2K1
+ movzx eax,WORD PTR [rdx]
+ mov WORD PTR [rbp],ax
+ add rdx,2 ; advance matrix B
+ add rbp,2 ; advance padded buffer destination
+
+CopyRemainingCountNLessThan2K1:
+ test r9b,1 ; (CountN & 1) != 0?
+ jz ProcessPaddedMatrixBDataK1
+ movzx eax,BYTE PTR [rdx]
+ mov BYTE PTR [rbp],al
+
+ProcessPaddedMatrixBDataK1:
+ vpmovzxbw ymm4,XMMWORD PTR GemmU8U8CopyPackBFrame.PaddedMatrixBData[rsp]
+ vmovdqu YMMWORD PTR [rcx],ymm4 ; store interleaved rows
+ vextracti128 xmm3,ymm4,1
+ vpmovzxbw ymm4,xmm4
+ vpmovzxbw ymm3,xmm3
+ vpaddw ymm0,ymm0,ymm4 ; accumulate per column
+ vpaddw ymm1,ymm1,ymm3
+
+ReduceColumnSumVectorNUnaligned:
+ vpmaddwd ymm0,ymm0,ymm5 ; multiply by offset and reduce
+ vpmaddwd ymm1,ymm1,ymm5 ; multiply by offset and reduce
+ vmovdqu YMMWORD PTR [r11],ymm0
+ vmovdqu YMMWORD PTR [r11+32],ymm1
+ jmp ExitRoutine
+
+ NESTED_END MlasGemmU8U8CopyPackBAvx2, _TEXT
+
+;
+; Macro Description:
+;
+; This macro generates code to multiply and accumulate a single row of the
+; output block.
+;
+; Arguments:
+;
+; ColumnCount - Supplies the number of columns to produce.
+;
+; Vec1Reg - Supplies the high block accumulator register (when ColumnCount
+; is 16).
+;
+; Vec2Reg - Supplies the low block accumulator register.
+;
+; Implicit Arguments:
+;
+; ymm0 - Supplies the first vector loaded from matrix B.
+;
+; ymm1 - Supplies the second vector loaded from matrix B (when ColumnCount
+; is 16).
+;
+; ymm2 - Supplies the broadcast value loaded from matrix A.
+;
+
+MultiplyAccumulateRow MACRO ColumnCount, Vec1Reg, Vec2Reg
+
+IF ColumnCount EQ 16
+ vpmaddwd ymm3,ymm2,ymm0
+ vpaddd Vec1Reg,Vec1Reg,ymm3
+ vpmaddwd ymm2,ymm2,ymm1
+ vpaddd Vec2Reg,Vec2Reg,ymm2
+ELSE
+ vpmaddwd ymm3,ymm2,ymm0
+ vpaddd Vec2Reg,Vec2Reg,ymm3
+ENDIF
+
+ ENDM
+
+;
+; Macro Description:
+;
+; This macro generates code to multiply and accumulate each row of the output
+; block.
+;
+; Arguments:
+;
+; ColumnCount - Supplies the number of columns to produce.
+; +; RowCount - Supplies the number of rows to produce. +; +; VectorOffset - Supplies the byte offset from matrix B to fetch elements. +; +; BroadcastOffset - Supplies the byte offset from matrix A to fetch elements. +; +; Implicit Arguments: +; +; rbx - Supplies the address into the matrix A data plus 3 rows. +; +; rcx - Supplies the address into the matrix A data. +; +; rdx - Supplies the address into the matrix B data. +; +; r10 - Supplies the length in bytes of a row from matrix A. +; +; ymm4-ymm15 - Supplies the block accumulators. +; + +ComputeBlock MACRO ColumnCount, RowCount, VectorOffset, BroadcastOffset + + vpmovzxbw ymm0,XMMWORD PTR [rdx+VectorOffset] + EmitIfCountGE ColumnCount, 16, + EmitIfCountGE RowCount, 1, + EmitIfCountGE RowCount, 1, + EmitIfCountGE RowCount, 2, + EmitIfCountGE RowCount, 2, + EmitIfCountGE RowCount, 3, + EmitIfCountGE RowCount, 3, + EmitIfCountGE RowCount, 4, + EmitIfCountGE RowCount, 4, + EmitIfCountGE RowCount, 5, + EmitIfCountGE RowCount, 5, + EmitIfCountGE RowCount, 6, + EmitIfCountGE RowCount, 6, + + ENDM + +; +; Macro Description: +; +; This macro generates code to produce an output block for a set of columns +; and rows. +; +; Arguments: +; +; ColumnCount - Supplies the number of columns to produce. +; +; RowCount - Supplies the number of rows to produce. +; +; Implicit Arguments: +; +; rax - Supplies the length in bytes of a row from matrix C. +; +; rcx - Supplies the address into the matrix A data. +; +; rdx - Supplies the address into the matrix B data. +; +; r9 - Supplies the number of paired columns from matrix A and the number of +; paired rows from matrix B to iterate over. +; +; r10 - Supplies the length in bytes of a row from matrix A. +; +; r12 - Supplies the address of the row sum vector. +; +; r13 - Supplies the address of the column sum vector. +; + +ProduceOutputBlock MACRO ColumnCount, RowCount + + LOCAL ComputeBlockLoop + LOCAL ProcessRemainingBlocks + LOCAL ComputeBlockLoopExit + +; +; Initialize the accumulators with the sum of the global depth value constant, +; the column sums, and the row sums. +; + + vpbroadcastd ymm1,DWORD PTR GemmU8U8KernelFrame.DepthValue[rsp] +IF ColumnCount EQ 16 + vpaddd ymm0,ymm1,YMMWORD PTR [r13] + vpaddd ymm1,ymm1,YMMWORD PTR [r13+32] + add r13,16*4 ; advance ColumnSumVector by 16 columns +ELSE + vpaddd ymm1,ymm1,YMMWORD PTR [r13] +ENDIF + EmitIfCountGE RowCount, 1, + EmitIfCountGE RowCount, 2, + EmitIfCountGE RowCount, 3, + EmitIfCountGE RowCount, 4, + EmitIfCountGE RowCount, 5, + EmitIfCountGE RowCount, 6, + EmitIfCount2GE RowCount, 1, ColumnCount, 16, + EmitIfCountGE RowCount, 1, + EmitIfCount2GE RowCount, 2, ColumnCount, 16, + EmitIfCountGE RowCount, 2, + EmitIfCount2GE RowCount, 3, ColumnCount, 16, + EmitIfCountGE RowCount, 3, + EmitIfCount2GE RowCount, 4, ColumnCount, 16, + EmitIfCountGE RowCount, 4, + EmitIfCount2GE RowCount, 5, ColumnCount, 16, + EmitIfCountGE RowCount, 5, + EmitIfCount2GE RowCount, 6, ColumnCount, 16, + EmitIfCountGE RowCount, 6, + +; +; Iterate over PairedCountK elements from matrix A and matrix B. +; +; Unrolling the loop to do two iterations improves performance slightly at the +; cost of larger code size. Balance this by only unrolling for the common case +; of computing 16 columns for an even number of rows. 
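+;
+; Each "pair" covers two adjacent values along the K dimension: packed matrix A
+; stores them as one 32-bit dword per row and packed matrix B stores them as 32
+; interleaved bytes, which is why one iteration advances matrix A by 4 bytes and
+; matrix B by 32 bytes.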
+; + + mov rsi,r9 ; reload PairedCountK +IF RowCount GT 3 + lea rbx,[r10*2+r10] + add rbx,rcx ; compute matrix A plus 3 rows +ENDIF + +IF (ColumnCount EQ 16) AND ((RowCount AND 1) EQ 0) + sub rsi,2 + jb ProcessRemainingBlocks + +ComputeBlockLoop: + ComputeBlock ColumnCount, RowCount, 0, 0 + ComputeBlock ColumnCount, RowCount, 32, 4 + add rcx,2*4 ; advance matrix A by 2 pairs +IF RowCount GT 3 + add rbx,2*4 ; advance matrix A plus 3 rows by 2 pairs +ENDIF + add rdx,2*32 ; advance matrix B by 64 columns + sub rsi,2 ; subtract pairs remaining + jae ComputeBlockLoop + +ProcessRemainingBlocks: + add rsi,2 ; correct for over-subtract above + jz ComputeBlockLoopExit + ComputeBlock ColumnCount, RowCount, 0, 0 + add rdx,32 ; advance matrix B by 32 columns +ELSE +ComputeBlockLoop: + ComputeBlock ColumnCount, RowCount, 0, 0 + add rcx,4 ; advance matrix A by 1 pair +IF RowCount GT 3 + add rbx,4 ; advance matrix A plus 3 rows by 1 pair +ENDIF + add rdx,32 + dec rsi ; decrement pairs remaining + jnz ComputeBlockLoop +ENDIF + +ComputeBlockLoopExit: +IF RowCount GT 3 + lea rbx,[r8+rax*2] ; compute matrix C plus 3 rows + add rbx,rax +ENDIF + + ENDM + +; +; Macro Description: +; +; This macro generates code to compute matrix multiplication for a fixed set +; of rows. +; +; Arguments: +; +; RowCount - Supplies the number of rows to process. +; +; Fallthrough - Supplies a non-blank value if the macro may fall through to +; the ExitKernel label. +; +; Implicit Arguments: +; +; rax - Supplies the length in bytes of a row from matrix C. +; +; rcx - Supplies the address of matrix A. +; +; rdx - Supplies the address of matrix B. +; +; r8 - Supplies the address of matrix C. +; +; rdi - Supplies the address of matrix A. +; +; rbp - Supplies the number of columns from matrix B and matrix C to iterate +; over. +; +; r9 - Supplies the number of paired columns from matrix A and the number of +; paired rows from matrix B to iterate over. +; +; r10 - Supplies the length in bytes of a row from matrix A. +; +; r12 - Supplies the address of the row sum vector. +; +; r13 - Supplies the address of the column sum vector. +; +; r14b - Supplies the zero mode flag. +; + +ProcessCountM MACRO RowCount, Fallthrough + + LOCAL ProcessNextColumnLoop16xN + LOCAL SkipAccumulateOutput16xNBlock + LOCAL OutputMasked16xNBlock + LOCAL ProcessRemainingCountN + LOCAL SkipAccumulateOutput8xNBlock + LOCAL SkipAccumulateOutputMasked16xNBlock + LOCAL OutputMasked8xNBlock + LOCAL SkipAccumulateOutputMasked8xNBlock + + cmp rbp,8 + jbe ProcessRemainingCountN + +ProcessNextColumnLoop16xN: + ProduceOutputBlock 16, RowCount + sub rbp,16 + jb OutputMasked16xNBlock + test r14b,r14b ; ZeroMode? 
+ jnz SkipAccumulateOutput16xNBlock + EmitIfCountGE RowCount, 1, + EmitIfCountGE RowCount, 1, + EmitIfCountGE RowCount, 2, + EmitIfCountGE RowCount, 2, + EmitIfCountGE RowCount, 3, + EmitIfCountGE RowCount, 3, + EmitIfCountGE RowCount, 4, + EmitIfCountGE RowCount, 4, + EmitIfCountGE RowCount, 5, + EmitIfCountGE RowCount, 5, + EmitIfCountGE RowCount, 6, + EmitIfCountGE RowCount, 6, + +SkipAccumulateOutput16xNBlock: + EmitIfCountGE RowCount, 1, + EmitIfCountGE RowCount, 1, + EmitIfCountGE RowCount, 2, + EmitIfCountGE RowCount, 2, + EmitIfCountGE RowCount, 3, + EmitIfCountGE RowCount, 3, + EmitIfCountGE RowCount, 4, + EmitIfCountGE RowCount, 4, + EmitIfCountGE RowCount, 5, + EmitIfCountGE RowCount, 5, + EmitIfCountGE RowCount, 6, + EmitIfCountGE RowCount, 6, + add r8,16*4 ; advance matrix C by 16 columns + mov rcx,rdi ; reload matrix A + cmp rbp,8 + ja ProcessNextColumnLoop16xN + test rbp,rbp + jz ExitKernel + +ProcessRemainingCountN: + ProduceOutputBlock 8, RowCount + cmp rbp,8 + jb OutputMasked8xNBlock + test r14b,r14b ; ZeroMode? + jnz SkipAccumulateOutput8xNBlock + EmitIfCountGE RowCount, 1, + EmitIfCountGE RowCount, 2, + EmitIfCountGE RowCount, 3, + EmitIfCountGE RowCount, 4, + EmitIfCountGE RowCount, 5, + EmitIfCountGE RowCount, 6, + +SkipAccumulateOutput8xNBlock: + EmitIfCountGE RowCount, 1, + EmitIfCountGE RowCount, 2, + EmitIfCountGE RowCount, 3, + EmitIfCountGE RowCount, 4, + EmitIfCountGE RowCount, 5, + EmitIfCountGE RowCount, 6, + jmp ExitKernel + +OutputMasked16xNBlock: + test r14b,r14b ; ZeroMode? + jnz SkipAccumulateOutputMasked16xNBlock + EmitIfCountGE RowCount, 1, + EmitIfCountGE RowCount, 2, + EmitIfCountGE RowCount, 3, + EmitIfCountGE RowCount, 4, + EmitIfCountGE RowCount, 5, + EmitIfCountGE RowCount, 6, + +SkipAccumulateOutputMasked16xNBlock: + EmitIfCountGE RowCount, 1, + EmitIfCountGE RowCount, 2, + EmitIfCountGE RowCount, 3, + EmitIfCountGE RowCount, 4, + EmitIfCountGE RowCount, 5, + EmitIfCountGE RowCount, 6, + add r8,8*4 ; advance matrix C by 8 columns +IF RowCount GT 3 + add rbx,8*4 ; advance matrix C plus 3 rows by 8 columns +ENDIF + add rbp,8 ; correct for over-subtract above + +OutputMasked8xNBlock: + mov DWORD PTR GemmU8U8KernelFrame.CountN[rsp],ebp + vpbroadcastd ymm0,DWORD PTR GemmU8U8KernelFrame.CountN[rsp] + vpcmpgtd ymm0,ymm0,YMMWORD PTR [MlasMaskMoveAvx] + test r14b,r14b ; ZeroMode? + jnz SkipAccumulateOutputMasked8xNBlock + EmitIfCountGE RowCount, 1, + EmitIfCountGE RowCount, 2, + EmitIfCountGE RowCount, 3, + EmitIfCountGE RowCount, 4, + EmitIfCountGE RowCount, 5, + EmitIfCountGE RowCount, 6, + EmitIfCountGE RowCount, 1, + EmitIfCountGE RowCount, 2, + EmitIfCountGE RowCount, 3, + EmitIfCountGE RowCount, 4, + EmitIfCountGE RowCount, 5, + EmitIfCountGE RowCount, 6, + +SkipAccumulateOutputMasked8xNBlock: + EmitIfCountGE RowCount, 1, + EmitIfCountGE RowCount, 2, + EmitIfCountGE RowCount, 3, + EmitIfCountGE RowCount, 4, + EmitIfCountGE RowCount, 5, + EmitIfCountGE RowCount, 6, +IFB + jmp ExitKernel +ENDIF + + ENDM + +;++ +; +; Routine Description: +; +; This routine is an inner kernel to compute matrix multiplication for a +; set of rows. +; +; Arguments: +; +; A (rcx) - Supplies the address of matrix A. The matrix data has been packed +; using MlasGemmU8U8CopyPackAAvx2. +; +; B (rdx) - Supplies the address of matrix B. The matrix data has been packed +; using MlasGemmU8U8CopyPackBAvx2. +; +; C (r8) - Supplies the address of matrix C. 
+; +; PairedCountK (r9) - Supplies the number of paired columns from matrix A and +; the number of paired rows from matrix B to iterate over. +; +; CountM - Supplies the maximum number of rows that can be processed for +; matrix A and matrix C. The actual number of rows handled for this +; invocation depends on the kernel implementation. +; +; CountN - Supplies the number of columns from matrix B and matrix C to iterate +; over. +; +; ldc - Supplies the first dimension of matrix C. +; +; RowSumVector - Supplies the sum of each row from matrix A multiplied by the +; zero point offset of matrix B. These values are accumulated into every +; row of matrix C. +; +; ColumnSumVector - Supplies the sum of each column from matrix B multiplied +; by the zero point offset of matrix A. These values are accumulated into +; every column of matrix C. +; +; DepthValue - Supplies the value CountK multiplied by the zero point offset +; of matrixA multplied by the zero point offset of matrix B. This value is +; accumulated into every element of matrix C. +; +; ZeroMode - Supplies true if the output matrix must be zero initialized, +; else false if the output matrix is accumulated into. +; +; Return Value: +; +; Returns the number of rows handled. +; +;-- + + NESTED_ENTRY MlasGemmU8U8KernelAvx2, _TEXT + + rex_push_reg rbp + push_reg rbx + push_reg rsi + push_reg rdi + push_reg r12 + push_reg r13 + push_reg r14 + alloc_stack (GemmU8U8KernelFrame.SavedR14) + save_xmm128_avx xmm6,GemmU8U8KernelFrame.SavedXmm6 + save_xmm128_avx xmm7,GemmU8U8KernelFrame.SavedXmm7 + save_xmm128_avx xmm8,GemmU8U8KernelFrame.SavedXmm8 + save_xmm128_avx xmm9,GemmU8U8KernelFrame.SavedXmm9 + save_xmm128_avx xmm10,GemmU8U8KernelFrame.SavedXmm10 + save_xmm128_avx xmm11,GemmU8U8KernelFrame.SavedXmm11 + save_xmm128_avx xmm12,GemmU8U8KernelFrame.SavedXmm12 + save_xmm128_avx xmm13,GemmU8U8KernelFrame.SavedXmm13 + save_xmm128_avx xmm14,GemmU8U8KernelFrame.SavedXmm14 + save_xmm128_avx xmm15,GemmU8U8KernelFrame.SavedXmm15 + + END_PROLOGUE + + mov rdi,rcx + mov rbp,GemmU8U8KernelFrame.CountN[rsp] + mov rax,GemmU8U8KernelFrame.ldc[rsp] + shl rax,2 ; convert ldc to bytes + lea r10,[r9*4] + mov r11,GemmU8U8KernelFrame.CountM[rsp] + mov r12,GemmU8U8KernelFrame.RowSumVector[rsp] + mov r13,GemmU8U8KernelFrame.ColumnSumVector[rsp] + movzx r14,BYTE PTR GemmU8U8KernelFrame.ZeroMode[rsp] + +; +; Process CountM rows of the matrices. +; + + cmp r11,5 + ja ProcessCountM6 + je ProcessCountM5 + cmp r11,3 + ja ProcessCountM4 + je ProcessCountM3 + cmp r11,1 + je ProcessCountM1 + +ProcessCountM2: + ProcessCountM 2 + +ProcessCountM4: + ProcessCountM 4 + +ProcessCountM6: + mov r11d,6 ; return 6 rows handled + ProcessCountM 6, Fallthrough + +; +; Restore non-volatile registers and return. 
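+;
+; ExitKernel is shared by every ProcessCountM expansion above; eax returns the
+; number of rows handled from r11d (forced to 6 for the fall-through case, else
+; the caller supplied CountM).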
+; + +ExitKernel: + mov eax,r11d + vzeroupper + vmovaps xmm6,GemmU8U8KernelFrame.SavedXmm6[rsp] + vmovaps xmm7,GemmU8U8KernelFrame.SavedXmm7[rsp] + vmovaps xmm8,GemmU8U8KernelFrame.SavedXmm8[rsp] + vmovaps xmm9,GemmU8U8KernelFrame.SavedXmm9[rsp] + vmovaps xmm10,GemmU8U8KernelFrame.SavedXmm10[rsp] + vmovaps xmm11,GemmU8U8KernelFrame.SavedXmm11[rsp] + vmovaps xmm12,GemmU8U8KernelFrame.SavedXmm12[rsp] + vmovaps xmm13,GemmU8U8KernelFrame.SavedXmm13[rsp] + vmovaps xmm14,GemmU8U8KernelFrame.SavedXmm14[rsp] + vmovaps xmm15,GemmU8U8KernelFrame.SavedXmm15[rsp] + add rsp,(GemmU8U8KernelFrame.SavedR14) + + BEGIN_EPILOGUE + + pop r14 + pop r13 + pop r12 + pop rdi + pop rsi + pop rbx + pop rbp + ret + +ProcessCountM1: + ProcessCountM 1 + +ProcessCountM3: + ProcessCountM 3 + +ProcessCountM5: + ProcessCountM 5 + + NESTED_END MlasGemmU8U8KernelAvx2, _TEXT + + END diff --git a/onnxruntime/core/mlas/lib/amd64/QgemmU8U8KernelAvx512BW.asm b/onnxruntime/core/mlas/lib/amd64/QgemmU8U8KernelAvx512BW.asm new file mode 100644 index 0000000000000..8f4d0fa47f7e2 --- /dev/null +++ b/onnxruntime/core/mlas/lib/amd64/QgemmU8U8KernelAvx512BW.asm @@ -0,0 +1,114 @@ +;++ +; +; Copyright (c) Microsoft Corporation. All rights reserved. +; +; Licensed under the MIT License. +; +; Module Name: +; +; QgemmU8U8KernelAvx512BW.asm +; +; Abstract: +; +; This module implements the kernels for the quantized integer matrix/matrix +; multiply operation (QGEMM). +; +; This implementation uses AVX512BW instructions. +; +;-- + + .xlist +INCLUDE mlasi.inc +INCLUDE QgemmU8U8KernelAvx512Common.inc + .list + +; +; Macro Description: +; +; This macro generates code to multiply and accumulator a single row of the +; output block. +; +; Arguments: +; +; ColumnCount - Supplies the number of columns to produce. +; +; Vec1Reg - Supplies the high block accumulator register (when ColumnCount +; is 32). +; +; Vec2Reg - Supplies the low block accumulator register. +; +; Implicit Arguments: +; +; zmm28 - Supplies the first vector loaded from matrix B. +; +; zmm29 - Supplies the second vector loaded from matrix B (when ColumnCount +; is 32). +; +; zmm30 - Supplies the broadcast value loaded from matrix A. +; + +MultiplyAccumulateRow MACRO ColumnCount, Vec1Reg, Vec2Reg + +IF ColumnCount EQ 32 + vpmaddwd zmm31,zmm30,zmm28 + vpaddd Vec1Reg,Vec1Reg,zmm31 + vpmaddwd zmm30,zmm30,zmm29 + vpaddd Vec2Reg,Vec2Reg,zmm30 +ELSE + vpmaddwd zmm31,zmm30,zmm28 + vpaddd Vec2Reg,Vec2Reg,zmm31 +ENDIF + + ENDM + +; +; Macro Description: +; +; This macro generates code to multiply and accumulate each row of the output +; block. +; +; Arguments: +; +; ColumnCount - Supplies the number of columns to produce. +; +; RowCount - Supplies the number of rows to produce. +; +; Implicit Arguments: +; +; rbx - Supplies the address into the matrix A data plus 3 rows. +; +; rcx - Supplies the address into the matrix A data. +; +; rdx - Supplies the address into the matrix B data. +; +; r10 - Supplies the length in bytes of a row from matrix A. +; +; zmm16-zmm27 - Supplies the block accumulators. 
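+; (two accumulators per row when ColumnCount is 32, so six rows span zmm16
+; through zmm27, leaving zmm28-zmm31 as scratch for the matrix B vectors, the
+; matrix A broadcast, and the intermediate products)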
+; + +ComputeBlock MACRO ColumnCount, RowCount + + vpmovzxbw zmm28,YMMWORD PTR [rdx] + EmitIfCountGE ColumnCount, 32, + EmitIfCountGE RowCount, 1, + EmitIfCountGE RowCount, 1, + EmitIfCountGE RowCount, 2, + EmitIfCountGE RowCount, 2, + EmitIfCountGE RowCount, 3, + EmitIfCountGE RowCount, 3, + EmitIfCountGE RowCount, 4, + EmitIfCountGE RowCount, 4, + EmitIfCountGE RowCount, 5, + EmitIfCountGE RowCount, 5, + EmitIfCountGE RowCount, 6, + EmitIfCountGE RowCount, 6, + + ENDM + +; +; Generate the GEMM kernel. +; + +GemmU8U8KernelAvx512Function Avx512BW + + END diff --git a/onnxruntime/core/mlas/lib/amd64/QgemmU8U8KernelAvx512Common.inc b/onnxruntime/core/mlas/lib/amd64/QgemmU8U8KernelAvx512Common.inc new file mode 100644 index 0000000000000..1cd5cdc732b12 --- /dev/null +++ b/onnxruntime/core/mlas/lib/amd64/QgemmU8U8KernelAvx512Common.inc @@ -0,0 +1,385 @@ +;++ +; +; Copyright (c) Microsoft Corporation. All rights reserved. +; +; Licensed under the MIT License. +; +; Module Name: +; +; QgemmU8U8KernelAvx512Common.inc +; +; Abstract: +; +; This module contains common kernel macros and structures for the quantized +; integer matrix/matrix multiply operation (QGEMM) for the AVX512BW and +; AVX512VNNI kernels. +; +;-- + +; +; Stack frame layout for the U8U8 kernel. +; + +GemmU8U8KernelFrame STRUCT + + SavedR14 QWORD ? + SavedR13 QWORD ? + SavedR12 QWORD ? + SavedRdi QWORD ? + SavedRsi QWORD ? + SavedRbx QWORD ? + SavedRbp QWORD ? + ReturnAddress QWORD ? + PreviousP1Home QWORD ? + PreviousP2Home QWORD ? + PreviousP3Home QWORD ? + PreviousP4Home QWORD ? + CountM QWORD ? + CountN QWORD ? + ldc QWORD ? + RowSumVector QWORD ? + ColumnSumVector QWORD ? + DepthValue QWORD ? + ZeroMode QWORD ? + +GemmU8U8KernelFrame ENDS + +; +; Macro Description: +; +; This macro generates code to produce an output block for a set of columns +; and rows. +; +; Arguments: +; +; ColumnCount - Supplies the number of columns to produce. +; +; RowCount - Supplies the number of rows to produce. +; +; Implicit Arguments: +; +; rcx - Supplies the address into the matrix A data. +; +; rdx - Supplies the address into the matrix B data. +; +; r9 - Supplies the number of paired columns from matrix A and the number of +; paired rows from matrix B to iterate over. +; +; r10 - Supplies the length in bytes of a row from matrix A. +; +; r12 - Supplies the address of the row sum vector. +; +; r13 - Supplies the address of the column sum vector. +; + +ProduceOutputBlock MACRO ColumnCount, RowCount + + LOCAL ComputeBlockLoop + +; +; Initialize the accumulators with the sum of the global depth value constant, +; the column sums, and the row sums. +; + + vpbroadcastd zmm31,DWORD PTR GemmU8U8KernelFrame.DepthValue[rsp] +IF ColumnCount EQ 32 + vpaddd zmm30,zmm31,ZMMWORD PTR [r13] + vpaddd zmm31,zmm31,ZMMWORD PTR [r13+64] + add r13,32*4 ; advance ColumnSumVector by 32 columns +ELSE + vpaddd zmm31,zmm31,ZMMWORD PTR [r13] +ENDIF + EmitIfCount2GE RowCount, 1, ColumnCount, 32, + EmitIfCountGE RowCount, 1, + EmitIfCount2GE RowCount, 2, ColumnCount, 32, + EmitIfCountGE RowCount, 2, + EmitIfCount2GE RowCount, 3, ColumnCount, 32, + EmitIfCountGE RowCount, 3, + EmitIfCount2GE RowCount, 4, ColumnCount, 32, + EmitIfCountGE RowCount, 4, + EmitIfCount2GE RowCount, 5, ColumnCount, 32, + EmitIfCountGE RowCount, 5, + EmitIfCount2GE RowCount, 6, ColumnCount, 32, + EmitIfCountGE RowCount, 6, + +; +; Iterate over PairedCountK elements from matrix A and matrix B. 
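+;
+; Unlike the AVX2 kernel, this loop is not unrolled: ComputeBlock consumes one
+; pair per iteration, advancing matrix A by 4 bytes and matrix B by 32 bytes.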
+; + + mov rsi,r9 ; reload PairedCountK +IF RowCount GT 3 + lea rbx,[r10*2+r10] + add rbx,rcx ; compute matrix A plus 3 rows +ENDIF + +ComputeBlockLoop: + ComputeBlock ColumnCount, RowCount + add rcx,4 ; advance matrix A by 1 pair +IF RowCount GT 3 + add rbx,4 ; advance matrix A plus 3 rows by 1 pair +ENDIF + add rdx,32 + dec rsi ; decrement pairs remaining + jnz ComputeBlockLoop + +IF RowCount GT 3 + lea rbx,[r8+rax*2] ; compute matrix C plus 3 rows + add rbx,rax +ENDIF + + ENDM + +; +; Macro Description: +; +; This macro generates code to compute matrix multiplication for a fixed set +; of rows. +; +; Arguments: +; +; RowCount - Supplies the number of rows to process. +; +; Implicit Arguments: +; +; rax - Supplies the length in bytes of a row from matrix C. +; +; rcx - Supplies the address of matrix A. +; +; rdx - Supplies the address of matrix B. +; +; r8 - Supplies the address of matrix C. +; +; rdi - Supplies the address of matrix A. +; +; rbp - Supplies the number of columns from matrix B and matrix C to iterate +; over. +; +; r9 - Supplies the number of paired columns from matrix A and the number of +; paired rows from matrix B to iterate over. +; +; r10 - Supplies the length in bytes of a row from matrix A. +; +; r12 - Supplies the address of the row sum vector. +; +; r13 - Supplies the address of the column sum vector. +; +; r14b - Supplies the zero mode flag. +; + +ProcessCountM MACRO RowCount + + LOCAL ProcessNextColumnLoop32xN + LOCAL SkipAccumulateOutput32xNBlock + LOCAL Output16xNBlock + LOCAL Output16xNBlockWithMask + LOCAL SkipAccumulateOutput16xNBlockWithMask + LOCAL ProcessRemainingCountN + + cmp rbp,16 + jbe ProcessRemainingCountN + +ProcessNextColumnLoop32xN: + ProduceOutputBlock 32, RowCount + lea rdx,[rdx+r10*8] ; advance matrix B by 8*PairedCountK + test r14b,r14b ; ZeroMode? + jnz SkipAccumulateOutput32xNBlock + EmitIfCountGE RowCount, 1, + EmitIfCountGE RowCount, 2, + EmitIfCountGE RowCount, 3, + EmitIfCountGE RowCount, 4, + EmitIfCountGE RowCount, 5, + EmitIfCountGE RowCount, 6, + +SkipAccumulateOutput32xNBlock: + EmitIfCountGE RowCount, 1, + EmitIfCountGE RowCount, 2, + EmitIfCountGE RowCount, 3, + EmitIfCountGE RowCount, 4, + EmitIfCountGE RowCount, 5, + EmitIfCountGE RowCount, 6, + add r8,16*4 ; advance matrix C by 16 columns +IF RowCount GT 3 + add rbx,16*4 ; advance matrix C plus 3 rows by 16 columns +ENDIF + sub rbp,16 + +Output16xNBlock: + sub rbp,16 + jae Output16xNBlockWithMask + lea ecx,[ebp+16] ; correct for over-subtract above + mov esi,1 + shl esi,cl + dec esi + kmovw k1,esi ; update mask for remaining columns + xor ebp,ebp ; no more columns remaining + +Output16xNBlockWithMask: + test r14b,r14b ; ZeroMode? + jnz SkipAccumulateOutput16xNBlockWithMask + EmitIfCountGE RowCount, 1, + EmitIfCountGE RowCount, 2, + EmitIfCountGE RowCount, 3, + EmitIfCountGE RowCount, 4, + EmitIfCountGE RowCount, 5, + EmitIfCountGE RowCount, 6, + +SkipAccumulateOutput16xNBlockWithMask: + EmitIfCountGE RowCount, 1, + EmitIfCountGE RowCount, 2, + EmitIfCountGE RowCount, 3, + EmitIfCountGE RowCount, 4, + EmitIfCountGE RowCount, 5, + EmitIfCountGE RowCount, 6, + add r8,16*4 ; advance matrix C by 16 columns + mov rcx,rdi ; reload matrix A + cmp rbp,16 + ja ProcessNextColumnLoop32xN + test rbp,rbp + jz ExitKernel + +ProcessRemainingCountN: + ProduceOutputBlock 16, RowCount + jmp Output16xNBlock + + ENDM + +; +; Macro Description: +; +; This macro generates the common AVX512 code for the inner kernel to compute +; matrix multiplication. 
+; +; Arguments: +; +; Isa - Supplies the instruction set architecture string for function tags. +; + +GemmU8U8KernelAvx512Function MACRO Isa + +;++ +; +; Routine Description: +; +; This routine is an inner kernel to compute matrix multiplication for a +; set of rows. +; +; Arguments: +; +; A (rcx) - Supplies the address of matrix A. The matrix data has been packed +; using MlasGemmU8U8CopyPackAAvx2. +; +; B (rdx) - Supplies the address of matrix B. The matrix data has been packed +; using MlasGemmU8U8CopyPackBAvx2. +; +; C (r8) - Supplies the address of matrix C. +; +; PairedCountK (r9) - Supplies the number of paired columns from matrix A and +; the number of paired rows from matrix B to iterate over. +; +; CountM - Supplies the maximum number of rows that can be processed for +; matrix A and matrix C. The actual number of rows handled for this +; invocation depends on the kernel implementation. +; +; CountN - Supplies the number of columns from matrix B and matrix C to iterate +; over. +; +; ldc - Supplies the first dimension of matrix C. +; +; RowSumVector - Supplies the sum of each row from matrix A multiplied by the +; zero point offset of matrix B. These values are accumulated into every +; row of matrix C. +; +; ColumnSumVector - Supplies the sum of each column from matrix B multiplied +; by the zero point offset of matrix A. These values are accumulated into +; every column of matrix C. +; +; DepthValue - Supplies the value CountK multiplied by the zero point offset +; of matrixA multplied by the zero point offset of matrix B. This value is +; accumulated into every element of matrix C. +; +; ZeroMode - Supplies true if the output matrix must be zero initialized, +; else false if the output matrix is accumulated into. +; +; Return Value: +; +; Returns the number of rows handled. +; +;-- + + NESTED_ENTRY MlasGemmU8U8Kernel&Isa&, _TEXT + + rex_push_reg rbp + push_reg rbx + push_reg rsi + push_reg rdi + push_reg r12 + push_reg r13 + push_reg r14 + + END_PROLOGUE + + mov rdi,rcx + mov rbp,GemmU8U8KernelFrame.CountN[rsp] + mov rax,GemmU8U8KernelFrame.ldc[rsp] + shl rax,2 ; convert ldc to bytes + lea r10,[r9*4] + mov r11,GemmU8U8KernelFrame.CountM[rsp] + mov r12,GemmU8U8KernelFrame.RowSumVector[rsp] + mov r13,GemmU8U8KernelFrame.ColumnSumVector[rsp] + movzx r14,BYTE PTR GemmU8U8KernelFrame.ZeroMode[rsp] + mov esi,-1 + kmovw k1,esi ; update mask to write all columns + +; +; Process CountM rows of the matrices. +; + + cmp r11,5 + ja ProcessCountM6 + je ProcessCountM5 + cmp r11,3 + ja ProcessCountM4 + je ProcessCountM3 + cmp r11,1 + je ProcessCountM1 + +ProcessCountM2: + ProcessCountM 2 + +ProcessCountM4: + ProcessCountM 4 + +ProcessCountM6: + mov r11d,6 ; return 6 rows handled + ProcessCountM 6 + +; +; Restore non-volatile registers and return. +; + +ExitKernel: + mov eax,r11d + + BEGIN_EPILOGUE + + pop r14 + pop r13 + pop r12 + pop rdi + pop rsi + pop rbx + pop rbp + ret + +ProcessCountM1: + ProcessCountM 1 + +ProcessCountM3: + ProcessCountM 3 + +ProcessCountM5: + ProcessCountM 5 + + NESTED_END MlasGemmU8U8Kernel&Isa&, _TEXT + + ENDM diff --git a/onnxruntime/core/mlas/lib/amd64/QgemmU8U8KernelAvx512Vnni.asm b/onnxruntime/core/mlas/lib/amd64/QgemmU8U8KernelAvx512Vnni.asm new file mode 100644 index 0000000000000..d2b6b696327b9 --- /dev/null +++ b/onnxruntime/core/mlas/lib/amd64/QgemmU8U8KernelAvx512Vnni.asm @@ -0,0 +1,91 @@ +;++ +; +; Copyright (c) Microsoft Corporation. All rights reserved. +; +; Licensed under the MIT License. 
+; +; Module Name: +; +; QgemmU8U8KernelAvx512Vnni.asm +; +; Abstract: +; +; This module implements the kernels for the quantized integer matrix/matrix +; multiply operation (QGEMM). +; +; This implementation uses AVX512VNNI instructions. +; +;-- + + .xlist +INCLUDE mlasi.inc +INCLUDE QgemmU8U8KernelAvx512Common.inc +INCLUDE AssembleAvx512Vnni.inc + .list + +; +; Macro Description: +; +; This macro generates code to multiply and accumulate each row of the output +; block. +; +; Arguments: +; +; ColumnCount - Supplies the number of columns to produce. +; +; RowCount - Supplies the number of rows to produce. +; +; Implicit Arguments: +; +; rbx - Supplies the address into the matrix A data plus 3 rows. +; +; rcx - Supplies the address into the matrix A data. +; +; rdx - Supplies the address into the matrix B data. +; +; r10 - Supplies the length in bytes of a row from matrix A. +; +; zmm16-zmm27 - Supplies the block accumulators. +; + +ComputeBlock MACRO ColumnCount, RowCount + + vpmovzxbw zmm28,YMMWORD PTR [rdx] +IF ColumnCount EQ 32 + vpmovzxbw zmm29,YMMWORD PTR [rdx+r10*8] + EmitIfCountGE RowCount, 1, + EmitIfCountGE RowCount, 1, + EmitIfCountGE RowCount, 1, + EmitIfCountGE RowCount, 2, + EmitIfCountGE RowCount, 2, + EmitIfCountGE RowCount, 2, + EmitIfCountGE RowCount, 3, + EmitIfCountGE RowCount, 3, + EmitIfCountGE RowCount, 3, + EmitIfCountGE RowCount, 4, + EmitIfCountGE RowCount, 4, + EmitIfCountGE RowCount, 4, + EmitIfCountGE RowCount, 5, + EmitIfCountGE RowCount, 5, + EmitIfCountGE RowCount, 5, + EmitIfCountGE RowCount, 6, + EmitIfCountGE RowCount, 6, + EmitIfCountGE RowCount, 6, +ELSE + EmitIfCountGE RowCount, 1, + EmitIfCountGE RowCount, 2, + EmitIfCountGE RowCount, 3, + EmitIfCountGE RowCount, 4, + EmitIfCountGE RowCount, 5, + EmitIfCountGE RowCount, 6, +ENDIF + + ENDM + +; +; Generate the GEMM kernel. +; + +GemmU8U8KernelAvx512Function Avx512Vnni + + END diff --git a/onnxruntime/core/mlas/lib/arm64/sgemma.asm b/onnxruntime/core/mlas/lib/arm64/SgemmKernelNeon.asm similarity index 91% rename from onnxruntime/core/mlas/lib/arm64/sgemma.asm rename to onnxruntime/core/mlas/lib/arm64/SgemmKernelNeon.asm index 0b6eb11fa2d78..3675689db6cd5 100644 --- a/onnxruntime/core/mlas/lib/arm64/sgemma.asm +++ b/onnxruntime/core/mlas/lib/arm64/SgemmKernelNeon.asm @@ -6,7 +6,7 @@ ; ; Module Name: ; -; sgemma.asm +; SgemmKernelNeon.asm ; ; Abstract: ; @@ -19,31 +19,6 @@ TEXTAREA -; -; ComputeEffectiveAddress -; -; Generates the code to compute the effective address of a matrix element using -; the instruction template: -; -; add $DestReg,$BaseReg,$IndexReg lsl #2 -; -; For native ARM64, the macro generates a 64-bit address calculation. For CHPE -; targets, the macro generates a 32-bit address calculation to stay within the -; WOW64 sandbox. 
-; - - - MACRO - ComputeEffectiveAddress $DestReg, $BaseReg, $IndexReg - -#if defined(_CHPE_X86_ARM64_) - DCD 0x0B000800:OR:(:RCONST:$DestReg):OR:((:RCONST:$BaseReg):SHL:5):OR:((:RCONST:$IndexReg):SHL:16) -#else - DCD 0x8B000800:OR:(:RCONST:$DestReg):OR:((:RCONST:$BaseReg):SHL:5):OR:((:RCONST:$IndexReg):SHL:16) -#endif - - MEND - ; ; ClearRowAccumulators ; @@ -171,11 +146,11 @@ ClearBlockAccumulators $Columns, $Rows IF $Rows >= 2 - ComputeEffectiveAddress x10,x0,x6 ; compute matrix A plus 1 row + add x10,x0,x6 lsl #2 ; compute matrix A plus 1 row ENDIF IF $Rows >= 4 - ComputeEffectiveAddress x11,x10,x6 ; compute matrix A plus 2 rows - ComputeEffectiveAddress x12,x11,x6 ; compute matrix A plus 3 rows + add x11,x10,x6 lsl #2 ; compute matrix A plus 2 rows + add x12,x11,x6 lsl #2 ; compute matrix A plus 3 rows ENDIF sub x9,x3,#4 ; decrement block count to process @@ -217,7 +192,7 @@ $Mode.Compute$Columns.x$Rows.BlockBy1Loop ldp v6,v7,[x1,#-8*4] ENDIF MultiplyAccumulateBlock $Columns,$Rows,0 - sub x9,x9,1 + sub x9,x9,#1 cbnz x9,$Mode.Compute$Columns.x$Rows.BlockBy1Loop $Mode.Output$Columns.x$Rows.Block @@ -476,9 +451,9 @@ $Mode.OutputRemaining1x$Rows.Block PROLOG_SAVE_REG_PAIR d8,d9,#-32! PROLOG_SAVE_REG_PAIR d10,d11,#16 - ComputeEffectiveAddress x13,x2,x7 ; compute matrix C plus 1 row - ComputeEffectiveAddress x14,x13,x7 ; compute matrix C plus 2 rows - ComputeEffectiveAddress x15,x14,x7 ; compute matrix C plus 3 rows + add x13,x2,x7 lsl #2 ; compute matrix C plus 1 row + add x14,x13,x7 lsl #2 ; compute matrix C plus 2 rows + add x15,x14,x7 lsl #2 ; compute matrix C plus 3 rows mov x8,x0 ; save matrix A ; diff --git a/onnxruntime/core/mlas/lib/erf.cpp b/onnxruntime/core/mlas/lib/erf.cpp index 12fd0a368d8ae..c1f4a7e6c2821 100644 --- a/onnxruntime/core/mlas/lib/erf.cpp +++ b/onnxruntime/core/mlas/lib/erf.cpp @@ -29,7 +29,7 @@ Module Name: // Bundles the constants for use by kernels written in assembly. // -extern "C" const struct { +MLAS_INTERNAL_DATA const struct { float ErfUpperAbsRange; float ErfSplitBoundary; float ErfSMALL_P0; diff --git a/onnxruntime/core/mlas/lib/logistic.cpp b/onnxruntime/core/mlas/lib/logistic.cpp index 03061bb6bafbd..9e657f1892cc4 100644 --- a/onnxruntime/core/mlas/lib/logistic.cpp +++ b/onnxruntime/core/mlas/lib/logistic.cpp @@ -26,7 +26,7 @@ Module Name: // Bundles the floating point constants for use by kernels written in assembly. // -extern "C" const struct { +MLAS_INTERNAL_DATA const struct { float LowerRange; float UpperRange; float alpha_9; diff --git a/onnxruntime/core/mlas/lib/mlasi.h b/onnxruntime/core/mlas/lib/mlasi.h index b191c155928d9..1ae227d1f9ac1 100644 --- a/onnxruntime/core/mlas/lib/mlasi.h +++ b/onnxruntime/core/mlas/lib/mlasi.h @@ -16,7 +16,6 @@ Module Name: --*/ #pragma once -// clang-format off #include #include @@ -56,6 +55,18 @@ Module Name: #define MLAS_FORCEINLINE __attribute__ ((always_inline)) inline #endif +// +// Macro to tag globals as internal data shared with kernels written in +// assembly. These globals are marked with having hidden visibility to avoid +// needing to access the data through the global object table. +// + +#if defined(_MSC_VER) +#define MLAS_INTERNAL_DATA extern "C" +#else +#define MLAS_INTERNAL_DATA extern "C" __attribute ((visibility("hidden"))) +#endif + // // Macro to suppress unreferenced parameter warnings. 
// @@ -69,7 +80,7 @@ Module Name: #if defined(_M_AMD64) || defined(__x86_64__) #define MLAS_TARGET_AMD64 #endif -#if (defined(_M_IX86) && !defined(_M_HYBRID_X86_ARM64)) || defined(__i386__) +#if defined(_M_IX86) || defined(__i386__) #define MLAS_TARGET_IX86 #endif #if defined(MLAS_TARGET_AMD64) || defined(MLAS_TARGET_IX86) @@ -92,8 +103,6 @@ Module Name: #if defined(_OPENMP) #include -#elif defined(_WIN32) -#define MLAS_USE_WIN32_THREADPOOL #endif // @@ -164,6 +173,52 @@ void typedef MLAS_SGEMM_TRANSPOSE_PACKB_BLOCK_ROUTINE* PMLAS_SGEMM_TRANSPOSE_PACKB_BLOCK_ROUTINE; +typedef +void +(MLASCALL MLAS_GEMM_U8U8_COPY_PACKA_ROUTINE)( + int16_t* D, + const uint8_t* A, + size_t lda, + size_t CountM, + size_t CountK, + int32_t* RowSumVector, + int16_t offb + ); + +typedef MLAS_GEMM_U8U8_COPY_PACKA_ROUTINE* PMLAS_GEMM_U8U8_COPY_PACKA_ROUTINE; + +typedef +void +(MLASCALL MLAS_GEMM_U8U8_COPY_PACKB_ROUTINE)( + uint8_t* D, + const uint8_t* B, + size_t ldb, + size_t CountN, + size_t CountK, + int32_t* ColumnSumVector, + int16_t offa + ); + +typedef MLAS_GEMM_U8U8_COPY_PACKB_ROUTINE* PMLAS_GEMM_U8U8_COPY_PACKB_ROUTINE; + +typedef +size_t +(MLASCALL MLAS_GEMM_U8U8_KERNEL)( + const int16_t* A, + const uint8_t* B, + int32_t* C, + size_t PairedCountK, + size_t CountM, + size_t CountN, + size_t ldc, + const int32_t* RowSumVector, + const int32_t* ColumnSumVector, + int32_t DepthValue, + bool ZeroMode + ); + +typedef MLAS_GEMM_U8U8_KERNEL* PMLAS_GEMM_U8U8_KERNEL; + typedef void (MLASCALL MLAS_CONV_FLOAT_KERNEL)( @@ -291,6 +346,19 @@ extern "C" { MLAS_SGEMM_TRANSPOSE_PACKB_BLOCK_ROUTINE MlasSgemmTransposePackB16x4Avx; #endif +#if defined(MLAS_TARGET_AMD64_IX86) + MLAS_GEMM_U8U8_COPY_PACKA_ROUTINE MlasGemmU8U8CopyPackASse; + MLAS_GEMM_U8U8_COPY_PACKB_ROUTINE MlasGemmU8U8CopyPackBSse; + MLAS_GEMM_U8U8_KERNEL MlasGemmU8U8KernelSse; +#if defined(MLAS_TARGET_AMD64) + MLAS_GEMM_U8U8_COPY_PACKA_ROUTINE MlasGemmU8U8CopyPackAAvx2; + MLAS_GEMM_U8U8_COPY_PACKB_ROUTINE MlasGemmU8U8CopyPackBAvx2; + MLAS_GEMM_U8U8_KERNEL MlasGemmU8U8KernelAvx2; + MLAS_GEMM_U8U8_KERNEL MlasGemmU8U8KernelAvx512BW; + MLAS_GEMM_U8U8_KERNEL MlasGemmU8U8KernelAvx512Vnni; +#endif +#endif + #if defined(MLAS_TARGET_AMD64) MLAS_CONV_FLOAT_KERNEL MlasConvNchwFloatKernelSse; MLAS_CONV_FLOAT_KERNEL MlasConvNchwcFloatKernelSse; @@ -406,6 +474,9 @@ struct MLAS_PLATFORM { #if defined(MLAS_TARGET_AMD64_IX86) PMLAS_SGEMM_KERNEL_ROUTINE KernelZeroRoutine; PMLAS_SGEMM_KERNEL_ROUTINE KernelAddRoutine; + PMLAS_GEMM_U8U8_COPY_PACKA_ROUTINE GemmU8U8CopyPackARoutine; + PMLAS_GEMM_U8U8_COPY_PACKB_ROUTINE GemmU8U8CopyPackBRoutine; + PMLAS_GEMM_U8U8_KERNEL GemmU8U8Kernel; #endif #if defined(MLAS_TARGET_AMD64) @@ -423,10 +494,6 @@ struct MLAS_PLATFORM { uint32_t NchwcBlockSize; uint32_t PreferredBufferAlignment; #endif - -#if defined(MLAS_USE_WIN32_THREADPOOL) - int32_t MaximumThreadCount; -#endif }; extern MLAS_PLATFORM MlasPlatform; @@ -462,13 +529,11 @@ MlasGetMaximumThreadCount( MLAS_UNREFERENCED_PARAMETER(ThreadPool); #else if (ThreadPool != nullptr) { - return ThreadPool->NumThreads(); + return ThreadPool->NumThreads() + 1; } #endif -#if defined(MLAS_USE_WIN32_THREADPOOL) - return MlasPlatform.MaximumThreadCount; -#elif _OPENMP +#if defined(_OPENMP) return (omp_get_num_threads() == 1) ? 
omp_get_max_threads() : 1; #else return 1; @@ -495,7 +560,7 @@ MlasGetMaximumThreadCount( #if defined(MLAS_TARGET_ARM) #define MLAS_NEON_INTRINSICS #define MLAS_NEON32_INTRINSICS -#elif defined(MLAS_TARGET_ARM64) || defined(_M_HYBRID_X86_ARM64) +#elif defined(MLAS_TARGET_ARM64) #define MLAS_NEON_INTRINSICS #define MLAS_NEON64_INTRINSICS #elif defined(MLAS_TARGET_AMD64_IX86) diff --git a/onnxruntime/core/mlas/lib/platform.cpp b/onnxruntime/core/mlas/lib/platform.cpp index 4f99d50fb27b0..1d0fdacae19d5 100644 --- a/onnxruntime/core/mlas/lib/platform.cpp +++ b/onnxruntime/core/mlas/lib/platform.cpp @@ -86,6 +86,9 @@ Return Value: this->KernelZeroRoutine = MlasSgemmKernelZeroSse; this->KernelAddRoutine = MlasSgemmKernelAddSse; + this->GemmU8U8CopyPackARoutine = MlasGemmU8U8CopyPackASse; + this->GemmU8U8CopyPackBRoutine = MlasGemmU8U8CopyPackBSse; + this->GemmU8U8Kernel = MlasGemmU8U8KernelSse; #if defined(MLAS_TARGET_AMD64) this->TransposePackB16x4Routine = MlasSgemmTransposePackB16x4Sse; this->ConvNchwFloatKernel = MlasConvNchwFloatKernelSse; @@ -157,6 +160,10 @@ Return Value: if (((Cpuid1[2] & 0x1000) != 0) && ((Cpuid7[1] & 0x20) != 0)) { + this->GemmU8U8CopyPackARoutine = MlasGemmU8U8CopyPackAAvx2; + this->GemmU8U8CopyPackBRoutine = MlasGemmU8U8CopyPackBAvx2; + this->GemmU8U8Kernel = MlasGemmU8U8KernelAvx2; + if (((Cpuid7[1] & 0x10000) != 0) && ((xcr0 & 0xE0) == 0xE0)) { this->KernelZeroRoutine = MlasSgemmKernelZeroAvx512F; @@ -171,6 +178,23 @@ Return Value: this->NchwcBlockSize = 16; this->PreferredBufferAlignment = 64; + // + // Check if the processor supports AVX512BW. + // + + if ((Cpuid7[1] & 0x40000000) != 0) { + + this->GemmU8U8Kernel = MlasGemmU8U8KernelAvx512BW; + + // + // Check if the processor supports AVX512VNNI. + // + + if ((Cpuid7[2] & 0x800) != 0) { + this->GemmU8U8Kernel = MlasGemmU8U8KernelAvx512Vnni; + } + } + } else { this->KernelZeroRoutine = MlasSgemmKernelZeroFma3; @@ -192,25 +216,6 @@ Return Value: } #endif - -#if defined(MLAS_USE_WIN32_THREADPOOL) - - // - // Retrieve the number of processors in the system. - // - - SYSTEM_INFO SystemInfo; - - GetSystemInfo(&SystemInfo); - - if (SystemInfo.dwNumberOfProcessors <= MLAS_MAXIMUM_THREAD_COUNT) { - this->MaximumThreadCount = int32_t(SystemInfo.dwNumberOfProcessors); - } else { - this->MaximumThreadCount = MLAS_MAXIMUM_THREAD_COUNT; - } - -#endif - } size_t @@ -223,7 +228,7 @@ MlasGetPreferredBufferAlignment( Routine Description: This routine returns the preferred byte alignment for buffers that are used - with this library. Buffers that are not bytes aligned to this value will + with this library. Buffers that are not byte aligned to this value will function, but will not achieve best performance. Arguments: diff --git a/onnxruntime/core/mlas/lib/qgemm.cpp b/onnxruntime/core/mlas/lib/qgemm.cpp new file mode 100644 index 0000000000000..c5d07da984769 --- /dev/null +++ b/onnxruntime/core/mlas/lib/qgemm.cpp @@ -0,0 +1,599 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + qgemm.cpp + +Abstract: + + This module implements the quantized integer matrix/matrix multiply + operation (QGEMM). + +--*/ + +#include "mlasi.h" + +// +// Define the default strides to step through slices of the input matrices. 
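+// The strides size the on-stack packed panels (PanelA and PanelB in MlasQgemm
+// below); a K stride of 128 also stays within the CountK limit of 256 imposed
+// by the 16-bit accumulators in the copy routines.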
+// + +#define MLAS_GEMM_U8U8_STRIDEM 12 +#define MLAS_GEMM_U8U8_STRIDEN 128 +#define MLAS_GEMM_U8U8_STRIDEK 128 + +#ifdef MLAS_TARGET_AMD64_IX86 + +void +MLASCALL +MlasGemmU8U8CopyPackASse( + int16_t* D, + const uint8_t* A, + size_t lda, + size_t CountM, + size_t CountK, + int32_t* RowSumVector, + int16_t offb + ) +/*++ + +Routine Description: + + This routine copies elements from the source matrix to the destination + packed buffer. + +Arguments: + + D - Supplies the address of the destination packed buffer. + + A - Supplies the address of the source matrix. + + lda - Supplies the number of elements per row of the source matrix. + + CountM - Supplies the number of rows of the source matrix to copy. + + CountK - Supplies the number of columns of the source matrix to copy. + + RowSumVector - Supplies the address of the buffer to receive the sums of + the elements from each of the rows. Each sum has also been multiplied + by the zero point offset. + + offb - Supplies the zero point offset for the other source matrix of the + matrix multiplication. + +Return Value: + + None. + +--*/ +{ + const __m128i ZeroVector = _mm_setzero_si128(); + const __m128i OffsetBroadcast = _mm_set1_epi16(offb); + uint8_t PaddedMatrixAData[8] = { 0 }; + + // + // Process a single row of matrix A in a loop. + // + + while (CountM > 0) { + + const uint8_t* a = A; + size_t k = CountK; + __m128i RowSum = ZeroVector; + + // + // Zero extend the source bytes to 16-bits and write to the packed + // buffer. The packed buffer has the same data ordering as the source + // bytes, but the stride is CountK aligned up to a multiple of 8 + // values. + // + // These 16-bit values are also accumulated into an intermediate per-row + // accumulator. CountK cannot be greater than 256 to avoid overflowing + // these 16-bit accumulators. + // + + while (k >= 8) { + + __m128i Bytes = _mm_loadl_epi64((__m128i*)&a[0]); + __m128i Words = _mm_unpacklo_epi8(Bytes, ZeroVector); + + RowSum = _mm_add_epi16(RowSum, Words); + + _mm_storeu_si128((__m128i*)&D[0], Words); + + D += 8; + a += 8; + k -= 8; + } + + if (k > 0) { + + // + // Copy the remaining bytes to the zero padded stack buffer. + // + + uint8_t* padded = PaddedMatrixAData; + uint8_t* padded_end = padded + k; + + do { + padded[0] = a[0]; + padded++; + a++; + } while (padded < padded_end); + + __m128i Bytes = _mm_loadl_epi64((__m128i*)PaddedMatrixAData); + __m128i Words = _mm_unpacklo_epi8(Bytes, ZeroVector); + + RowSum = _mm_add_epi16(RowSum, Words); + + // + // Copy the 16-bit pairs from the vector to the destination packed + // buffer. Rotate the vector at each iteration. + // + + for (size_t pairs = (k + 1) / 2; pairs > 0; pairs--) { + *((int32_t*)D) = _mm_cvtsi128_si32(Words); + D += 2; + Words = _mm_shuffle_epi32(Words, _MM_SHUFFLE(0, 3, 2, 1)); + } + } + + // + // Reduce the sum for the single row of output. + // + + RowSum = _mm_madd_epi16(RowSum, OffsetBroadcast); + RowSum = _mm_add_epi32(RowSum, _mm_shuffle_epi32(RowSum, _MM_SHUFFLE(3, 2, 3, 2))); + RowSum = _mm_add_epi32(RowSum, _mm_shuffle_epi32(RowSum, _MM_SHUFFLE(0, 1, 0, 1))); + + *RowSumVector++ = _mm_cvtsi128_si32(RowSum); + + A += lda; + CountM -= 1; + } +} + +void +MLASCALL +MlasGemmU8U8CopyPackBSse( + uint8_t* D, + const uint8_t* B, + size_t ldb, + size_t CountN, + size_t CountK, + int32_t* ColumnSumVector, + int16_t offa + ) +/*++ + +Routine Description: + + This routine copies elements from the source matrix to the destination + packed buffer. 
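+ Pairs of source rows are interleaved in the packed buffer so the compute
+ kernel can zero extend the byte pairs and feed them to _mm_madd_epi16.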
+ +Arguments: + + D (rcx) - Supplies the address of the destination packed buffer. + + B (rdx) - Supplies the address of the source matrix. + + ldb (r8) - Supplies the number of elements per row of the source matrix. + + CountN (r9) - Supplies the number of columns of the source matrix to copy. + + CountK - Supplies the number of rows of the source matrix to copy. + + ColumnSumVector - Supplies the address of the buffer to receive the sums of + the elements from each of the columns. Each sum has also been multiplied + by the zero point offset. + + offa - Supplies the zero point offset for the other source matrix of the + matrix multiplication. + +Return Value: + + None. + +--*/ +{ + const __m128i ZeroVector = _mm_setzero_si128(); + const __m128i OffsetBroadcast = _mm_set1_epi16(offa); + uint8_t PaddedMatrixBData[16] = { 0 }; + + // + // Process 8 columns of matrix B in a loop. + // + + while (CountN >= 8) { + + const uint8_t* b = B; + size_t k = CountK; + __m128i ColumnSum0 = ZeroVector; + __m128i ColumnSum1 = ZeroVector; + + while (k >= 2) { + + __m128i BytesRow0 = _mm_loadl_epi64((__m128i*)&b[0]); + __m128i BytesRow1 = _mm_loadl_epi64((__m128i*)&b[ldb]); + __m128i BytesInterleaved = _mm_unpacklo_epi8(BytesRow0, BytesRow1); + + _mm_storeu_si128((__m128i*)&D[0], BytesInterleaved); + + ColumnSum0 = _mm_add_epi16(ColumnSum0, _mm_unpacklo_epi8(BytesInterleaved, ZeroVector)); + ColumnSum1 = _mm_add_epi16(ColumnSum1, _mm_unpackhi_epi8(BytesInterleaved, ZeroVector)); + + b += ldb * 2; + D += 16; + k -= 2; + } + + if (k > 0) { + + __m128i BytesRow0 = _mm_loadl_epi64((__m128i*)&b[0]); + __m128i BytesInterleaved = _mm_unpacklo_epi8(BytesRow0, ZeroVector); + + _mm_storeu_si128((__m128i*)&D[0], BytesInterleaved); + + ColumnSum0 = _mm_add_epi16(ColumnSum0, _mm_unpacklo_epi8(BytesInterleaved, ZeroVector)); + ColumnSum1 = _mm_add_epi16(ColumnSum1, _mm_unpackhi_epi8(BytesInterleaved, ZeroVector)); + + b += ldb * 2; + D += 16; + } + + ColumnSum0 = _mm_madd_epi16(ColumnSum0, OffsetBroadcast); + ColumnSum1 = _mm_madd_epi16(ColumnSum1, OffsetBroadcast); + + _mm_storeu_si128((__m128i*)&ColumnSumVector[0], ColumnSum0); + _mm_storeu_si128((__m128i*)&ColumnSumVector[4], ColumnSum1); + + ColumnSumVector += 8; + + B += 8; + CountN -= 8; + } + + // + // Process the remaining columns of matrix B. + // + + if (CountN > 0) { + + const uint8_t* b = B; + size_t k = CountK; + __m128i ColumnSum0 = ZeroVector; + __m128i ColumnSum1 = ZeroVector; + + while (k >= 2) { + + // + // Copy the remaining columns to the zero padded stack buffer. + // + + const uint8_t* bcopy = b; + uint8_t* padded = PaddedMatrixBData; + uint8_t* padded_end = padded + CountN; + + do { + padded[0] = bcopy[0]; + padded[8] = bcopy[ldb]; + padded++; + bcopy++; + } while (padded < padded_end); + + __m128i BytesRow0 = _mm_loadl_epi64((__m128i*)&PaddedMatrixBData[0]); + __m128i BytesRow1 = _mm_loadl_epi64((__m128i*)&PaddedMatrixBData[8]); + __m128i BytesInterleaved = _mm_unpacklo_epi8(BytesRow0, BytesRow1); + + _mm_storeu_si128((__m128i*)&D[0], BytesInterleaved); + + ColumnSum0 = _mm_add_epi16(ColumnSum0, _mm_unpacklo_epi8(BytesInterleaved, ZeroVector)); + ColumnSum1 = _mm_add_epi16(ColumnSum1, _mm_unpackhi_epi8(BytesInterleaved, ZeroVector)); + + b += ldb * 2; + D += 16; + k -= 2; + } + + if (k > 0) { + + // + // Copy the remaining columns to the zero padded stack buffer. 
+ // + + const uint8_t* bcopy = b; + uint8_t* padded = PaddedMatrixBData; + uint8_t* padded_end = padded + CountN; + + do { + padded[0] = bcopy[0]; + padded++; + bcopy++; + } while (padded < padded_end); + + __m128i BytesRow0 = _mm_loadl_epi64((__m128i*)&PaddedMatrixBData[0]); + __m128i BytesInterleaved = _mm_unpacklo_epi8(BytesRow0, ZeroVector); + + _mm_storeu_si128((__m128i*)&D[0], BytesInterleaved); + + ColumnSum0 = _mm_add_epi16(ColumnSum0, _mm_unpacklo_epi8(BytesInterleaved, ZeroVector)); + ColumnSum1 = _mm_add_epi16(ColumnSum1, _mm_unpackhi_epi8(BytesInterleaved, ZeroVector)); + } + + ColumnSum0 = _mm_madd_epi16(ColumnSum0, OffsetBroadcast); + ColumnSum1 = _mm_madd_epi16(ColumnSum1, OffsetBroadcast); + + _mm_storeu_si128((__m128i*)&ColumnSumVector[0], ColumnSum0); + _mm_storeu_si128((__m128i*)&ColumnSumVector[4], ColumnSum1); + } +} + +size_t +MLASCALL +MlasGemmU8U8KernelSse( + const int16_t* A, + const uint8_t* B, + int32_t* C, + size_t PairedCountK, + size_t CountM, + size_t CountN, + size_t ldc, + const int32_t* RowSumVector, + const int32_t* ColumnSumVector, + int32_t DepthValue, + bool ZeroMode + ) +/*++ + +Routine Description: + + This routine is an inner kernel to compute matrix multiplication for a + set of rows. + +Arguments: + + A - Supplies the address of matrix A. The matrix data has been packed + using MlasGemmU8U8CopyPackASse. + + B - Supplies the address of matrix B. The matrix data has been packed + using MlasGemmU8U8CopyPackBSse. + + C - Supplies the address of matrix C. + + PairedCountK - Supplies the number of paired columns from matrix A and + the number of paired rows from matrix B to iterate over. + + CountM - Supplies the maximum number of rows that can be processed for + matrix A and matrix C. The actual number of rows handled for this + invocation depends on the kernel implementation. + + CountN - Supplies the number of columns from matrix B and matrix C to iterate + over. + + ldc - Supplies the first dimension of matrix C. + + RowSumVector - Supplies the sum of each row from matrix A multiplied by the + zero point offset of matrix B. These values are accumulated into every + row of matrix C. + + ColumnSumVector - Supplies the sum of each column from matrix B multiplied + by the zero point offset of matrix A. These values are accumulated into + every column of matrix C. + + DepthValue - Supplies the value CountK multiplied by the zero point offset + of matrixA multplied by the zero point offset of matrix B. This value is + accumulated into every element of matrix C. + + ZeroMode - Supplies true if the output matrix must be zero initialized, + else false if the output matrix is accumulated into. + +Return Value: + + Returns the number of rows handled. + +--*/ +{ + const __m128i ZeroVector = _mm_setzero_si128(); + + MLAS_UNREFERENCED_PARAMETER(CountM); + MLAS_UNREFERENCED_PARAMETER(ldc); + + while (CountN > 0) { + + // + // Initialize the accumulators with the sum of the global depth value + // constant, the column sums, and the row sums. 
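+ //
+ // For reference, with zero points offa (matrix A) and offb (matrix B):
+ //
+ //   sum((A[m][k] - offa) * (B[k][n] - offb))
+ //     = sum(A[m][k] * B[k][n])
+ //       + RowSumVector[m]      (row sum of A times -offb)
+ //       + ColumnSumVector[n]   (column sum of B times -offa)
+ //       + DepthValue           (CountK * offa * offb)
+ //
+ // so only the plain u8 by u8 dot products remain for the loop below.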
+ // + + __m128i Accumulator0 = _mm_set1_epi32(DepthValue); + Accumulator0 = _mm_add_epi32(Accumulator0, _mm_set1_epi32(RowSumVector[0])); + __m128i Accumulator1 = Accumulator0; + Accumulator0 = _mm_add_epi32(Accumulator0, _mm_loadu_si128((__m128i*)&ColumnSumVector[0])); + Accumulator1 = _mm_add_epi32(Accumulator1, _mm_loadu_si128((__m128i*)&ColumnSumVector[4])); + ColumnSumVector += 8; + + // + // Broadcast each pair of 16-bit values from the matrix A and multiply + // with the zero-extended pair of 16-bit values from matrix B, and add + // the 32-bit intermediate into the accumulator registers. + // + + const int16_t* a = A; + size_t k = PairedCountK; + + while (k > 0) { + + __m128i AElements0 = _mm_set1_epi32(*((int32_t*)a)); + __m128i BElements0 = _mm_loadu_si128((__m128i*)&B[0]); + + __m128i Intermediate0 = _mm_unpacklo_epi8(BElements0, ZeroVector); + __m128i Intermediate1 = _mm_unpackhi_epi8(BElements0, ZeroVector); + + Intermediate0 = _mm_madd_epi16(Intermediate0, AElements0); + Intermediate1 = _mm_madd_epi16(Intermediate1, AElements0); + + Accumulator0 = _mm_add_epi32(Accumulator0, Intermediate0); + Accumulator1 = _mm_add_epi32(Accumulator1, Intermediate1); + + a += 2; + B += 16; + k -= 1; + } + + // + // Output the accumulator block after optionally accumulating the values + // from matrix C. + // + + if (CountN >= 8) { + + if (!ZeroMode) { + Accumulator0 = _mm_add_epi32(Accumulator0, _mm_loadu_si128((__m128i*)&C[0])); + Accumulator1 = _mm_add_epi32(Accumulator1, _mm_loadu_si128((__m128i*)&C[4])); + } + + _mm_storeu_si128((__m128i*)&C[0], Accumulator0); + _mm_storeu_si128((__m128i*)&C[4], Accumulator1); + + C += 8; + CountN -= 8; + + } else { + + // + // Output the remaining partial output block. + // + + if ((CountN & 4) != 0) { + + if (!ZeroMode) { + Accumulator0 = _mm_add_epi32(Accumulator0, _mm_loadu_si128((__m128i*)&C[0])); + } + + _mm_storeu_si128((__m128i*)&C[0], Accumulator0); + C += 4; + + Accumulator0 = Accumulator1; + } + + if ((CountN & 2) != 0) { + + if (!ZeroMode) { + Accumulator0 = _mm_add_epi32(Accumulator0, _mm_loadl_epi64((__m128i*)&C[0])); + } + + _mm_storel_epi64((__m128i*)&C[0], Accumulator0); + C += 2; + + Accumulator0 = _mm_shuffle_epi32(Accumulator0, _MM_SHUFFLE(1, 0, 3, 2)); + } + + if ((CountN & 1) != 0) { + + int32_t AccumulatorValue = _mm_cvtsi128_si32(Accumulator0); + + if (!ZeroMode) { + AccumulatorValue += C[0]; + } + + C[0] = AccumulatorValue; + } + + break; + } + } + + return 1; +} + +void +MLASCALL +MlasQgemm( + size_t M, + size_t N, + size_t K, + const uint8_t* A, + size_t lda, + uint8_t offa, + const uint8_t* B, + size_t ldb, + uint8_t offb, + int32_t* C, + size_t ldc, + MLAS_THREADPOOL* ThreadPool + ) +{ + MLAS_DECLSPEC_ALIGN(int16_t PanelA[MLAS_GEMM_U8U8_STRIDEM * MLAS_GEMM_U8U8_STRIDEK], 64); + MLAS_DECLSPEC_ALIGN(uint8_t PanelB[MLAS_GEMM_U8U8_STRIDEN * MLAS_GEMM_U8U8_STRIDEK], 64); + + MLAS_DECLSPEC_ALIGN(int32_t RowSumVector[MLAS_GEMM_U8U8_STRIDEM], 16); + MLAS_DECLSPEC_ALIGN(int32_t ColumnSumVector[MLAS_GEMM_U8U8_STRIDEN], 16); + + size_t StrideM = MLAS_GEMM_U8U8_STRIDEM; + size_t StrideN = MLAS_GEMM_U8U8_STRIDEN; + size_t StrideK = MLAS_GEMM_U8U8_STRIDEK; + + MLAS_UNREFERENCED_PARAMETER(ThreadPool); + + size_t CountK; + + for (size_t k = 0; k < K; k += CountK) { + + CountK = StrideK; + + if (CountK > (K - k)) { + CountK = K - k; + } + + size_t CountN; + + for (size_t n = 0; n < N; n += CountN) { + + CountN = StrideN; + + if (CountN > (N - n)) { + CountN = N - n; + } + + MlasPlatform.GemmU8U8CopyPackBRoutine(PanelB, B + n + k * ldb, ldb, 
CountN, CountK, ColumnSumVector, -int16_t(offa)); + + size_t CountM; + + for (size_t m = 0; m < M; m += CountM) { + + CountM = StrideM; + + if (CountM > (M - m)) { + CountM = M - m; + } + + MlasPlatform.GemmU8U8CopyPackARoutine(PanelA, A + k + m * lda, lda, CountM, CountK, RowSumVector, -int16_t(offb)); + + int16_t* pa = PanelA; + int32_t* c = C + n + m * ldc; + + int32_t* RowSums = RowSumVector; + + size_t RowsRemaining = CountM; + size_t RowsHandled; + + size_t PairedCountK = (CountK + 1) / 2; + + while (RowsRemaining > 0) { + + RowsHandled = MlasPlatform.GemmU8U8Kernel(pa, PanelB, c, PairedCountK, RowsRemaining, CountN, ldc, RowSums, ColumnSumVector, int32_t(CountK) * offa * offb, k == 0); + + RowsRemaining -= RowsHandled; + c += ldc * RowsHandled; + pa += 2 * PairedCountK * RowsHandled; + RowSums += RowsHandled; + } + } + } + } +} + +#endif diff --git a/onnxruntime/core/mlas/lib/sgemm.cpp b/onnxruntime/core/mlas/lib/sgemm.cpp index f436f17e6cd9c..5250c84487b93 100644 --- a/onnxruntime/core/mlas/lib/sgemm.cpp +++ b/onnxruntime/core/mlas/lib/sgemm.cpp @@ -55,7 +55,7 @@ struct MLAS_SGEMM_WORK_BLOCK { // Stores a vector to build a conditional load/store mask for vmaskmovps. // -extern "C" MLAS_DECLSPEC_ALIGN(const uint32_t MlasMaskMoveAvx[8], 8 * sizeof(float)) = { 0, 1, 2, 3, 4, 5, 6, 7 }; +MLAS_INTERNAL_DATA MLAS_DECLSPEC_ALIGN(const uint32_t MlasMaskMoveAvx[8], 8 * sizeof(float)) = { 0, 1, 2, 3, 4, 5, 6, 7 }; #endif diff --git a/onnxruntime/core/mlas/lib/tanh.cpp b/onnxruntime/core/mlas/lib/tanh.cpp index 430afdf60e225..2fbeaef3d9815 100644 --- a/onnxruntime/core/mlas/lib/tanh.cpp +++ b/onnxruntime/core/mlas/lib/tanh.cpp @@ -26,7 +26,7 @@ Module Name: // Bundles the floating point constants for use by kernels written in assembly. // -extern "C" const struct { +MLAS_INTERNAL_DATA const struct { float LowerRange; float UpperRange; float alpha_13; diff --git a/onnxruntime/core/mlas/lib/threading.cpp b/onnxruntime/core/mlas/lib/threading.cpp index 858b72722e8bc..ef30de9499bb2 100644 --- a/onnxruntime/core/mlas/lib/threading.cpp +++ b/onnxruntime/core/mlas/lib/threading.cpp @@ -16,59 +16,6 @@ Module Name: #include "mlasi.h" -#if defined(MLAS_USE_WIN32_THREADPOOL) - -// -// Define the parameters to execute threaded work using the Windows thread pool -// library. -// - -struct MLAS_THREADED_WORK_BLOCK { - volatile LONG Counter; - PMLAS_THREADED_ROUTINE ThreadedRoutine; - void* Context; -}; - -void -CALLBACK -MlasThreadedWorkCallback( - PTP_CALLBACK_INSTANCE Instance, - void* Context, - PTP_WORK WorkObject - ) -/*++ - -Routine Description: - - This routine is invoked from a worker thread to execute one iteration of a - batch of threaded work. - -Arguments: - - Instance - Supplies the callback instance object. - - Context - Supplies the pointer to the parameters for the operation. - - WorkObject - Supplies the threadpool work object. - -Return Value: - - None. - ---*/ -{ - MLAS_UNREFERENCED_PARAMETER(Instance); - MLAS_UNREFERENCED_PARAMETER(WorkObject); - - MLAS_THREADED_WORK_BLOCK* WorkBlock = (MLAS_THREADED_WORK_BLOCK*)Context; - - LONG Index = InterlockedIncrement(&WorkBlock->Counter) - 1; - - WorkBlock->ThreadedRoutine(WorkBlock->Context, Index); -} - -#endif - void MlasExecuteThreaded( MLAS_THREADED_ROUTINE ThreadedRoutine, @@ -99,48 +46,11 @@ MlasExecuteThreaded( } #endif -#if defined(MLAS_USE_WIN32_THREADPOOL) // - // Schedule the threaded iterations using a work object. + // Fallback to OpenMP or a serialized implementation. 
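[Editor's note] For context on the "Fallback to OpenMP or a serialized implementation" comment above: with the Win32 thread pool path removed, the dispatch reduces to something like the hedged sketch below. The OpenMP guard and exact structure are assumptions, not taken from this patch; it only illustrates the fallback behavior the comment names.

    #include <cstdint>

    typedef void (THREADED_ROUTINE_SKETCH)(void* Context, int32_t Index);

    // Run every iteration either via an OpenMP parallel loop or serially on the
    // calling thread.
    void ExecuteThreadedSketch(THREADED_ROUTINE_SKETCH* ThreadedRoutine,
                               void* Context, int32_t Iterations)
    {
    #if defined(_OPENMP)
        #pragma omp parallel for
        for (int32_t tid = 0; tid < Iterations; tid++) {
            ThreadedRoutine(Context, tid);
        }
    #else
        for (int32_t tid = 0; tid < Iterations; tid++) {
            ThreadedRoutine(Context, tid);
        }
    #endif
    }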
// - MLAS_THREADED_WORK_BLOCK WorkBlock; - - PTP_WORK WorkObject = CreateThreadpoolWork(MlasThreadedWorkCallback, &WorkBlock, nullptr); - - if (WorkObject != nullptr) { - - WorkBlock.Counter = 0; - WorkBlock.ThreadedRoutine = ThreadedRoutine; - WorkBlock.Context = Context; - - for (int32_t tid = 1; tid < Iterations; tid++) { - SubmitThreadpoolWork(WorkObject); - } - - // - // Execute the remaining iteration on this thread. - // - - ThreadedRoutine(Context, Iterations - 1); - - // - // Wait for the work object callbacks to complete. - // - - WaitForThreadpoolWorkCallbacks(WorkObject, FALSE); - CloseThreadpoolWork(WorkObject); - - return; - } - - // - // Fallback to a serialized implementation. - // - -#endif - // // Execute the routine for the specified number of iterations. // diff --git a/onnxruntime/core/mlas/lib/x86_64/AssembleAvx512Vnni.h b/onnxruntime/core/mlas/lib/x86_64/AssembleAvx512Vnni.h new file mode 100644 index 0000000000000..bd3112bd9ccd9 --- /dev/null +++ b/onnxruntime/core/mlas/lib/x86_64/AssembleAvx512Vnni.h @@ -0,0 +1,238 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + AssembleAvx512Vnni.h + +Abstract: + + This module contains macros to build VNNI instructions for toolchains that + do not natively support this newer instruction set extension. + +--*/ + +// +// Map friendly register names to the encoded register index. +// + + .equ .LZmmIndex_zmm0, 0 + .equ .LZmmIndex_zmm1, 1 + .equ .LZmmIndex_zmm2, 2 + .equ .LZmmIndex_zmm3, 3 + .equ .LZmmIndex_zmm4, 4 + .equ .LZmmIndex_zmm5, 5 + .equ .LZmmIndex_zmm6, 6 + .equ .LZmmIndex_zmm7, 7 + .equ .LZmmIndex_zmm8, 8 + .equ .LZmmIndex_zmm9, 9 + .equ .LZmmIndex_zmm10, 10 + .equ .LZmmIndex_zmm11, 11 + .equ .LZmmIndex_zmm12, 12 + .equ .LZmmIndex_zmm13, 13 + .equ .LZmmIndex_zmm14, 14 + .equ .LZmmIndex_zmm15, 15 + .equ .LZmmIndex_zmm16, 16 + .equ .LZmmIndex_zmm17, 17 + .equ .LZmmIndex_zmm18, 18 + .equ .LZmmIndex_zmm19, 19 + .equ .LZmmIndex_zmm20, 20 + .equ .LZmmIndex_zmm21, 21 + .equ .LZmmIndex_zmm22, 22 + .equ .LZmmIndex_zmm23, 23 + .equ .LZmmIndex_zmm24, 24 + .equ .LZmmIndex_zmm25, 25 + .equ .LZmmIndex_zmm26, 26 + .equ .LZmmIndex_zmm27, 27 + .equ .LZmmIndex_zmm28, 28 + .equ .LZmmIndex_zmm29, 29 + .equ .LZmmIndex_zmm30, 30 + .equ .LZmmIndex_zmm31, 31 + + .equ .LGprIndex_rax, 0 + .equ .LGprIndex_rcx, 1 + .equ .LGprIndex_rdx, 2 + .equ .LGprIndex_rbx, 3 + .equ .LGprIndex_rbp, 5 + .equ .LGprIndex_rsi, 6 + .equ .LGprIndex_rdi, 7 + .equ .LGprIndex_r8, 8 + .equ .LGprIndex_r9, 9 + .equ .LGprIndex_r10, 10 + .equ .LGprIndex_r11, 11 + .equ .LGprIndex_r12, 12 + .equ .LGprIndex_r13, 13 + .equ .LGprIndex_r14, 14 + .equ .LGprIndex_r15, 15 + +/*++ + +Macro Description: + + This macro builds a VNNI instruction of the form: + + instr zmm1,zmm2,zmm3 + +Arguments: + + Opcode - Specifies the opcode for the VNNI instruction. + + DestReg - Specifies the destination register. + + Src1Reg - Specifies the first source register. + + Src2Reg - Specifies the second source register. 
+ +--*/ + + .macro VnniZmmZmmZmm Opcode, DestReg, Src1Reg, Src2Reg + + .set Payload0, 0x02 # "0F 38" prefix + .set Payload0, Payload0 + ((((.LZmmIndex_\DestReg\() >> 3) & 1) ^ 1) << 7) + .set Payload0, Payload0 + ((((.LZmmIndex_\Src2Reg\() >> 4) & 1) ^ 1) << 6) + .set Payload0, Payload0 + ((((.LZmmIndex_\Src2Reg\() >> 3) & 1) ^ 1) << 5) + .set Payload0, Payload0 + ((((.LZmmIndex_\DestReg\() >> 4) & 1) ^ 1) << 4) + + .set Payload1, 0x05 # "66" prefix + .set Payload1, Payload1 + (((.LZmmIndex_\Src1Reg\() & 15) ^ 15) << 3) + + .set Payload2, 0x40 # 512-bit vector length + .set Payload2, Payload2 + ((((.LZmmIndex_\Src1Reg\() >> 4) & 1) ^ 1) << 3) + + .set ModRMByte, 0xC0 # register form + .set ModRMByte, ModRMByte + ((.LZmmIndex_\DestReg\() & 7) << 3) + .set ModRMByte, ModRMByte + (.LZmmIndex_\Src2Reg\() & 7) + + .byte 0x62, Payload0, Payload1, Payload2, \Opcode\(), ModRMByte + + .endm + + .macro VpdpbusdZmmZmmZmm DestReg, Src1Reg, Src2Reg + + VnniZmmZmmZmm 0x50, \DestReg\(), \Src1Reg\(), \Src2Reg\() + + .endm + + .macro VpdpbusdsZmmZmmZmm DestReg, Src1Reg, Src2Reg + + VnniZmmZmmZmm 0x51, \DestReg\(), \Src1Reg\(), \Src2Reg\() + + .endm + + .macro VpdpwssdZmmZmmZmm DestReg, Src1Reg, Src2Reg + + VnniZmmZmmZmm 0x52, \DestReg\(), \Src1Reg\(), \Src2Reg\() + + .endm + + .macro VpdpwssdsZmmZmmZmm DestReg, Src1Reg, Src2Reg + + VnniZmmZmmZmm 0x53, \DestReg\(), \Src1Reg\(), \Src2Reg\() + + .endm + +/*++ + +Macro Description: + + This macro builds a VNNI instruction of the form: + + instr zmm1,zmm2,DWORD PTR [BaseReg+IndexReg*Scale]{1to16} + +Arguments: + + Opcode - Specifies the opcode for the VNNI instruction. + + DestReg - Specifies the destination register. + + Src1Reg - Specifies the first source register. + + BaseReg - Specifies the base register of the broadcast operand. + + IndexReg - Specifies the optional index register of the broadcast operand. + + Scale - Specifies the scaling factor of the optional index register. 
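[Editor's note] The bit fiddling in VnniZmmZmmZmm above is hand-assembled EVEX encoding. The small C++ mirror below (written only to illustrate the macro's arithmetic, not part of MLAS) reproduces the byte stream; for example, VpdpbusdZmmZmmZmm zmm0,zmm1,zmm2 expands to the bytes 62 F2 75 48 50 C2, which is the standard encoding of vpdpbusd zmm0,zmm1,zmm2.

    #include <cstdint>

    // Byte 0 is the fixed EVEX escape 0x62, followed by the three payload bytes,
    // the opcode, and the register-form ModRM byte.
    struct EvexBytes {
        uint8_t Bytes[6];
    };

    constexpr EvexBytes EncodeVnniRegReg(uint8_t Opcode, unsigned Dest,
                                         unsigned Src1, unsigned Src2)
    {
        uint8_t Payload0 = uint8_t(0x02 |                        // map "0F 38"
                                   ((((Dest >> 3) & 1) ^ 1) << 7) |
                                   ((((Src2 >> 4) & 1) ^ 1) << 6) |
                                   ((((Src2 >> 3) & 1) ^ 1) << 5) |
                                   ((((Dest >> 4) & 1) ^ 1) << 4));
        uint8_t Payload1 = uint8_t(0x05 | (((Src1 & 15) ^ 15) << 3));      // "66" prefix
        uint8_t Payload2 = uint8_t(0x40 | ((((Src1 >> 4) & 1) ^ 1) << 3)); // 512-bit length
        uint8_t ModRM = uint8_t(0xC0 | ((Dest & 7) << 3) | (Src2 & 7));
        return EvexBytes{{0x62, Payload0, Payload1, Payload2, Opcode, ModRM}};
    }

    // Example: vpdpbusd zmm0,zmm1,zmm2 -> 62 F2 75 48 50 C2.
    constexpr EvexBytes VpdpbusdZmm0Zmm1Zmm2 = EncodeVnniRegReg(0x50, 0, 1, 2);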
+ +--*/ + + .macro VnniZmmZmmBroadcast Opcode, DestReg, Src1Reg, BaseReg, IndexReg, Scale + + .set Payload0, 0x02 # "0F 38" prefix + .set Payload0, Payload0 + ((((.LZmmIndex_\DestReg\() >> 3) & 1) ^ 1) << 7) +.ifnes "\IndexReg\()", "" + .set Payload0, Payload0 + ((((.LGprIndex_\IndexReg\() >> 3) & 1) ^ 1) << 6) +.else + .set Payload0, Payload0 + 0x40 # zero logical index register +.endif + .set Payload0, Payload0 + ((((.LGprIndex_\BaseReg\() >> 3) & 1) ^ 1) << 5) + .set Payload0, Payload0 + ((((.LZmmIndex_\DestReg\() >> 4) & 1) ^ 1) << 4) + + .set Payload1, 0x05 # "66" prefix + .set Payload1, Payload1 + (((.LZmmIndex_\Src1Reg\() & 15) ^ 15) << 3) + + .set Payload2, 0x50 # 512-bit vector length, broadcast + .set Payload2, Payload2 + ((((.LZmmIndex_\Src1Reg\() >> 4) & 1) ^ 1) << 3) + + .set ModRMByte, 0x00 # memory form + .set ModRMByte, ModRMByte + ((.LZmmIndex_\DestReg\() & 7) << 3) +.ifnes "\IndexReg\()", "" + .set ModRMByte, ModRMByte + 0x04 # indicate SIB byte needed +.else + .set ModRMByte, ModRMByte + (.LGprIndex_\BaseReg\() & 7) +.endif + +.ifnes "\IndexReg\()", "" + .set SibByte, 0 +.ifeqs "\Scale\()", "2" + .set SibByte, SibByte + (1 << 6) +.else +.ifeqs "\Scale\()", "4" + .set SibByte, SibByte + (2 << 6) +.else +.ifeqs "\Scale\()", "8" + .set SibByte, SibByte + (3 << 6) +.else +.ifnes "\Scale\()", "1" + .err +.endif +.endif +.endif +.endif + .set SibByte, SibByte + ((.LGprIndex_\IndexReg\() & 7) << 3) + .set SibByte, SibByte + (.LGprIndex_\BaseReg\() & 7) +.endif + +.ifnes "\IndexReg\()", "" + .byte 0x62, Payload0, Payload1, Payload2, \Opcode\(), ModRMByte, SibByte +.else + .byte 0x62, Payload0, Payload1, Payload2, \Opcode\(), ModRMByte +.endif + + .endm + + .macro VpdpbusdZmmZmmBroadcast DestReg, Src1Reg, BaseReg, IndexReg, Scale + + VnniZmmZmmBroadcast 0x50, \DestReg\(), \Src1Reg\(), \BaseReg\(), \IndexReg\(), \Scale\() + + .endm + + .macro VpdpbusdsZmmZmmBroadcast DestReg, Src1Reg, BaseReg, IndexReg, Scale + + VnniZmmZmmBroadcast 0x51, \DestReg\(), \Src1Reg\(), \BaseReg\(), \IndexReg\(), \Scale\() + + .endm + + .macro VpdpwssdZmmZmmBroadcast DestReg, Src1Reg, BaseReg, IndexReg, Scale + + VnniZmmZmmBroadcast 0x52, \DestReg\(), \Src1Reg\(), \BaseReg\(), \IndexReg\(), \Scale\() + + .endm + + .macro VpdpwssdsZmmZmmBroadcast DestReg, Src1Reg, BaseReg, IndexReg, Scale + + VnniZmmZmmBroadcast 0x53, \DestReg\(), \Src1Reg\(), \BaseReg\(), \IndexReg\(), \Scale\() + + .endm diff --git a/onnxruntime/core/mlas/lib/x86_64/ErfKernelFma3.S b/onnxruntime/core/mlas/lib/x86_64/ErfKernelFma3.S index 29518fb91119a..92b7976d7db79 100644 --- a/onnxruntime/core/mlas/lib/x86_64/ErfKernelFma3.S +++ b/onnxruntime/core/mlas/lib/x86_64/ErfKernelFma3.S @@ -26,6 +26,7 @@ Abstract: // // Structure layout for the erf constants block. 
// + .equ ErfUpperAbsRange, 0 .equ ErfSplitBoundary, 4 .equ ErfSMALL_P0, 8 @@ -68,7 +69,7 @@ Abstract: .equ ErfBuffer1, 128 .equ ErfKernelFrame_CountN, 256 .equ ErfKernelFrame_ReturnAddress, 256+8 - + /*++ Routine Description: @@ -92,7 +93,7 @@ Return Value: .globl C_UNDERSCORE(MlasErfKernelFma3) C_UNDERSCORE(MlasErfKernelFma3): sub rsp,ErfKernelFrame_ReturnAddress - mov rax,C_UNDERSCORE(MlasErfConstants)@GOTPCREL[rip] + lea rax,C_UNDERSCORE(MlasErfConstants)[rip] sub rdx,8*4 jb .LErfProcessRemainingCount @@ -376,10 +377,9 @@ C_UNDERSCORE(MlasErfKernelFma3): .LErfProcess1x8: mov DWORD PTR ErfKernelFrame_CountN[rsp],edx - mov rcx,QWORD PTR C_UNDERSCORE(MlasMaskMoveAvx)@GOTPCREL[rip] vbroadcastss ymm3,DWORD PTR ErfKernelFrame_CountN[rsp] - vpcmpgtd ymm3,ymm3,YMMWORD PTR [rcx] + vpcmpgtd ymm3,ymm3,YMMWORD PTR C_UNDERSCORE(MlasMaskMoveAvx)[rip] vbroadcastss ymm15,ErfNegZero[rax] vmaskmovps ymm0,ymm3,YMMWORD PTR [rdi] # original input vx0 diff --git a/onnxruntime/core/mlas/lib/x86_64/LogisticKernelFma3.S b/onnxruntime/core/mlas/lib/x86_64/LogisticKernelFma3.S index 8b7b27dcbb0ed..243b355398eb6 100644 --- a/onnxruntime/core/mlas/lib/x86_64/LogisticKernelFma3.S +++ b/onnxruntime/core/mlas/lib/x86_64/LogisticKernelFma3.S @@ -72,7 +72,7 @@ Return Value: .globl C_UNDERSCORE(MlasLogisticKernelFma3) C_UNDERSCORE(MlasLogisticKernelFma3): - mov rax,C_UNDERSCORE(MlasLogisticConstants)@GOTPCREL[rip] + lea rax,C_UNDERSCORE(MlasLogisticConstants)[rip] vbroadcastss ymm4,LogisticConstants_LowerRange[rax] vbroadcastss ymm5,LogisticConstants_UpperRange[rax] vbroadcastss ymm6,LogisticConstants_alpha_9[rax] @@ -120,9 +120,8 @@ C_UNDERSCORE(MlasLogisticKernelFma3): add rdx,8 # correct for over-subtract above jz .LExitKernel mov DWORD PTR LogisticKernelFrame_CountN[rsp],edx - mov rcx,QWORD PTR C_UNDERSCORE(MlasMaskMoveAvx)@GOTPCREL[rip] vbroadcastss ymm2,DWORD PTR LogisticKernelFrame_CountN[rsp] - vpcmpgtd ymm2,ymm2,YMMWORD PTR [rcx] + vpcmpgtd ymm2,ymm2,YMMWORD PTR C_UNDERSCORE(MlasMaskMoveAvx)[rip] vmaskmovps ymm0,ymm2,YMMWORD PTR [rdi] vmaxps ymm0,ymm4,ymm0 # clamp lower bound vminps ymm0,ymm5,ymm0 # clamp upper bound diff --git a/onnxruntime/core/mlas/lib/x86_64/QgemmU8U8KernelAvx2.S b/onnxruntime/core/mlas/lib/x86_64/QgemmU8U8KernelAvx2.S new file mode 100644 index 0000000000000..8837be62c5e2e --- /dev/null +++ b/onnxruntime/core/mlas/lib/x86_64/QgemmU8U8KernelAvx2.S @@ -0,0 +1,1121 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + QgemmU8U8KernelAvx2.s + +Abstract: + + This module implements the kernels for the quantized integer matrix/matrix + multiply operation (QGEMM). + + This implementation uses AVX2 instructions. + +--*/ + +#include "asmmacro.h" + + .intel_syntax noprefix + + .text + +// +// Stack frame layout for the U8U8 CopyPackA routine. +// + + .equ .LGemmU8U8CopyPackAFrame_PaddedMatrixAData, -72 + .equ .LGemmU8U8CopyPackAFrame_mask, -8 + .equ .LGemmU8U8CopyPackAFrame_SavedR13, 0 + .equ .LGemmU8U8CopyPackAFrame_SavedR12, 8 + .equ .LGemmU8U8CopyPackAFrame_SavedRbx, 16 + .equ .LGemmU8U8CopyPackAFrame_SavedRbp, 24 + .equ .LGemmU8U8CopyPackAFrame_ReturnAddress, 32 + .equ .LGemmU8U8CopyPackAFrame_offb, 40 + +// +// Stack frame layout for the U8U8 CopyPackB routine. 
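[Editor's note] The @GOTPCREL loads dropped in the hunks above become unnecessary once the constant tables are no longer exported symbols: a direct RIP-relative lea (or memory operand) resolves at link time without a GOT indirection. The actual definition of MLAS_INTERNAL_DATA is not shown in this patch; the sketch below is only an assumption about the idea behind it.

    #include <cstdint>

    // Hypothetical sketch: give the shared constant table internal (hidden)
    // visibility so position-independent code can address it RIP-relatively.
    #if defined(__GNUC__)
    #define INTERNAL_DATA_SKETCH extern "C" __attribute__((visibility("hidden")))
    #else
    #define INTERNAL_DATA_SKETCH extern "C"
    #endif

    INTERNAL_DATA_SKETCH const uint32_t MaskMoveTableSketch[8] = {0, 1, 2, 3, 4, 5, 6, 7};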
+// + + .equ .LGemmU8U8CopyPackBFrame_PaddedMatrixBData, -40 + .equ .LGemmU8U8CopyPackBFrame_Padding, -8 + .equ .LGemmU8U8CopyPackBFrame_SavedRbx, 0 + .equ .LGemmU8U8CopyPackBFrame_SavedRbp, 8 + .equ .LGemmU8U8CopyPackBFrame_ReturnAddress, 16 + .equ .LGemmU8U8CopyPackBFrame_offa, 24 + +// +// Stack frame layout for the U8U8 kernel. +// + + .equ .LGemmU8U8KernelFrame_mask, -8 + .equ .LGemmU8U8KernelFrame_SavedR14, 0 + .equ .LGemmU8U8KernelFrame_SavedR13, 8 + .equ .LGemmU8U8KernelFrame_SavedR12, 16 + .equ .LGemmU8U8KernelFrame_SavedRbx, 24 + .equ .LGemmU8U8KernelFrame_SavedRbp, 32 + .equ .LGemmU8U8KernelFrame_ReturnAddress, 40 + .equ .LGemmU8U8KernelFrame_ldc, 48 + .equ .LGemmU8U8KernelFrame_RowSumVector, 56 + .equ .LGemmU8U8KernelFrame_ColumnSumVector, 64 + .equ .LGemmU8U8KernelFrame_DepthValue, 72 + .equ .LGemmU8U8KernelFrame_ZeroMode, 80 + +/*++ + +Routine Description: + + This routine copies elements from the source matrix to the destination + packed buffer. + + The kernel expects that elements from matrix A have been zero extended to + 16-bits and padded to a multiple of 32-bits (two pairs of 16-bit values). + The kernel can then efficiently broadcast 32-bits from the packed buffer + and avoid expensive shuffling inside the kernel. + +Arguments: + + D (rdi) - Supplies the address of the destination packed buffer. + + A (rsi) - Supplies the address of the source matrix. + + lda (rdx) - Supplies the number of elements per row of the source matrix. + + CountM (rcx) - Supplies the number of rows of the source matrix to copy. + + CountK (r8) - Supplies the number of columns of the source matrix to copy. + + RowSumVector (r9) - Supplies the address of the buffer to receive the sums + of the elements from each of the rows. Each sum has also been multiplied + by the zero point offset. + + offb - Supplies the zero point offset for the other source matrix of the + matrix multiplication. + +Return Value: + + None. + +--*/ + + .globl C_UNDERSCORE(MlasGemmU8U8CopyPackAAvx2) +C_UNDERSCORE(MlasGemmU8U8CopyPackAAvx2): + + push rbp + push rbx + push r12 + push r13 + + mov r10,rdx + mov r11,rcx + lea r12,[r8+1] + and r12,NOT 1 # align CountK up to pair count + vpbroadcastw xmm8,WORD PTR .LGemmU8U8CopyPackAFrame_offb[rsp] + +// +// Compute the conditional load/store mask for an unaligned CountK. +// + + mov eax,r8d + and eax,15 # isolate unaligned count + inc eax + shr eax,1 # align unaligned count to pair count + mov DWORD PTR .LGemmU8U8CopyPackAFrame_mask[rsp],eax + vpbroadcastd ymm9,DWORD PTR .LGemmU8U8CopyPackAFrame_mask[rsp] + vpcmpgtd ymm9,ymm9,YMMWORD PTR C_UNDERSCORE(MlasMaskMoveAvx)[rip] + +// +// Zero initialize the padded stack buffers. +// + + vpxor xmm0,xmm0,xmm0 + vmovdqu YMMWORD PTR .LGemmU8U8CopyPackAFrame_PaddedMatrixAData[rsp],ymm0 + vmovdqu YMMWORD PTR .LGemmU8U8CopyPackAFrame_PaddedMatrixAData[rsp+32],ymm0 + +// +// Process 4 rows of matrix A in a loop. +// +// For each row, zero extend the source bytes to 16-bits and write to the packed +// buffer. The packed buffer has the same data ordering as the source bytes, but +// the stride is CountK aligned up to an even number of 16-bit values. +// +// These 16-bit values are also accumulated into an intermediate per-row +// accumulator. CountK cannot be greater than 256 to avoid overflowing these +// 16-bit accumulators. 
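[Editor's note] The 256-element limit mentioned above is a simple worst-case bound: every packed byte is at most 255, so a block of 256 of them sums to at most 255 * 256 = 65280, which still fits in an unsigned 16-bit lane. A one-line check of that bound (illustrative only):

    #include <cstdint>

    // Worst case for the per-row 16-bit accumulators with CountK <= 256.
    static_assert(255u * 256u <= UINT16_MAX,
                  "a block of 256 unsigned bytes cannot overflow a 16-bit sum");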
+// + + sub r11,4 + jb .LCopyPackA.ProcessRemainingRows + +.LCopyPackA.ProcessNextRowM4: + vpxor xmm0,xmm0,xmm0 # clear row accumulators + vpxor xmm1,xmm1,xmm1 + vpxor xmm2,xmm2,xmm2 + vpxor xmm3,xmm3,xmm3 + mov rdx,rsi + mov rcx,rdi + lea rsi,[rsi+r10*4] # advance next matrix A by 4 rows + lea rdi,[rdi+r12*(2*4)] # advance next matrix D by 4 rows + mov rbx,r8 # reload columns remaining + sub rbx,16 + jb .LCopyPackA.ProcessRemainingColumnsM4 + +.LCopyPackA.ProcessNextColumnLoopM4: + lea rax,[rdx+r10*2] # compute matrix A plus two rows + vpmovzxbw ymm4,XMMWORD PTR [rdx] + vpmovzxbw ymm5,XMMWORD PTR [rdx+r10] + vpmovzxbw ymm6,XMMWORD PTR [rax] + vpmovzxbw ymm7,XMMWORD PTR [rax+r10] + lea rax,[rcx+r12*4] # compute matrix D plus two rows + vmovdqu YMMWORD PTR [rcx],ymm4 + vmovdqu YMMWORD PTR [rcx+r12*2],ymm5 + vmovdqu YMMWORD PTR [rax],ymm6 + vmovdqu YMMWORD PTR [rax+r12*2],ymm7 + vpaddw ymm0,ymm0,ymm4 # accumulate per row along columns + vpaddw ymm1,ymm1,ymm5 + vpaddw ymm2,ymm2,ymm6 + vpaddw ymm3,ymm3,ymm7 + add rdx,16 # advance matrix A by 16 bytes + add rcx,16*2 # advance matrix D by 16 words + sub rbx,16 # subtract columns remaining + jae .LCopyPackA.ProcessNextColumnLoopM4 + +.LCopyPackA.ProcessRemainingColumnsM4: + add rbx,16 # correct for over-subtract above + jz .LCopyPackA.ReduceRowSumVectorM4 + +// +// Copy the unaligned CountK columns to a zero padded stack buffer. +// + + lea rbp,.LGemmU8U8CopyPackAFrame_PaddedMatrixAData[rsp] + test bl,8 # (CountK & 8) != 0? + jz .LCopyPackA.CopyRemainingCountKLessThan8M4 + lea r13,[rdx+r10*2] # compute matrix A plus two rows + mov rax,QWORD PTR [rdx] + mov QWORD PTR [rbp],rax + mov rax,QWORD PTR [rdx+r10] + mov QWORD PTR [rbp+16],rax + mov rax,QWORD PTR [r13] + mov QWORD PTR [rbp+32],rax + mov rax,QWORD PTR [r13+r10] + mov QWORD PTR [rbp+48],rax + add rdx,8 + add rbp,8 # advance padded buffer destination + +.LCopyPackA.CopyRemainingCountKLessThan8M4: + test bl,4 # (CountK & 4) != 0? + jz .LCopyPackA.CopyRemainingCountKLessThan4M4 + lea r13,[rdx+r10*2] # compute matrix A plus two rows + mov eax,DWORD PTR [rdx] + mov DWORD PTR [rbp],eax + mov eax,DWORD PTR [rdx+r10] + mov DWORD PTR [rbp+16],eax + mov eax,DWORD PTR [r13] + mov DWORD PTR [rbp+32],eax + mov eax,DWORD PTR [r13+r10] + mov DWORD PTR [rbp+48],eax + add rdx,4 + add rbp,4 # advance padded buffer destination + +.LCopyPackA.CopyRemainingCountKLessThan4M4: + test bl,2 # (CountK & 2) != 0? + jz .LCopyPackA.CopyRemainingCountKLessThan2M4 + lea r13,[rdx+r10*2] # compute matrix A plus two rows + movzx eax,WORD PTR [rdx] + mov WORD PTR [rbp],ax + movzx eax,WORD PTR [rdx+r10] + mov WORD PTR [rbp+16],ax + movzx eax,WORD PTR [r13] + mov WORD PTR [rbp+32],ax + movzx eax,WORD PTR [r13+r10] + mov WORD PTR [rbp+48],ax + add rdx,2 + add rbp,2 # advance padded buffer destination + +.LCopyPackA.CopyRemainingCountKLessThan2M4: + test bl,1 # (CountK & 1) != 0? + jz .LCopyPackA.ProcessPaddedMatrixADataM4 + lea r13,[rdx+r10*2] # compute matrix A plus two rows + movzx eax,BYTE PTR [rdx] + mov BYTE PTR [rbp],al + movzx eax,BYTE PTR [rdx+r10] + mov BYTE PTR [rbp+16],al + movzx eax,BYTE PTR [r13] + mov BYTE PTR [rbp+32],al + movzx eax,BYTE PTR [r13+r10] + mov BYTE PTR [rbp+48],al + +// +// Process the remaining CountK columns using the zero padded stack buffer. 
+// + +.LCopyPackA.ProcessPaddedMatrixADataM4: + vpmovzxbw ymm4,XMMWORD PTR .LGemmU8U8CopyPackAFrame_PaddedMatrixAData[rsp] + vpmovzxbw ymm5,XMMWORD PTR .LGemmU8U8CopyPackAFrame_PaddedMatrixAData[rsp+16] + vpmovzxbw ymm6,XMMWORD PTR .LGemmU8U8CopyPackAFrame_PaddedMatrixAData[rsp+32] + vpmovzxbw ymm7,XMMWORD PTR .LGemmU8U8CopyPackAFrame_PaddedMatrixAData[rsp+48] + lea rax,[rcx+r12*4] # compute matrix D plus two rows + vpmaskmovd YMMWORD PTR [rcx],ymm9,ymm4 + vpmaskmovd YMMWORD PTR [rcx+r12*2],ymm9,ymm5 + vpmaskmovd YMMWORD PTR [rax],ymm9,ymm6 + vpmaskmovd YMMWORD PTR [rax+r12*2],ymm9,ymm7 + vpaddw ymm0,ymm0,ymm4 # accumulate per row along columns + vpaddw ymm1,ymm1,ymm5 + vpaddw ymm2,ymm2,ymm6 + vpaddw ymm3,ymm3,ymm7 + +// +// Reduce the sums for the four rows of output. Transpose the intermediate +// accumulators by treating the registers as 32-bit elements containing a pair +// of 16-bit sums. Continue reducing the transposed accumulators to produce the +// final 32-bit vector output. +// + +.LCopyPackA.ReduceRowSumVectorM4: + vpunpckldq ymm4,ymm0,ymm1 # [A5 B5 A4 B4 A1 B1 A0 B0] + vpunpckhdq ymm5,ymm0,ymm1 # [A7 B7 A6 B6 A3 B3 A2 B2] + vpunpckldq ymm6,ymm2,ymm3 # [C5 D5 C4 D4 C1 D1 C0 D0] + vpunpckhdq ymm7,ymm2,ymm3 # [C7 D7 C6 D6 C3 D3 C2 D2] + vpunpcklqdq ymm0,ymm4,ymm6 # [A4 B4 C4 D4 A0 B0 C0 D0] + vpunpckhqdq ymm1,ymm4,ymm6 # [A5 B5 C5 D5 A1 B1 C1 D1] + vpunpcklqdq ymm2,ymm5,ymm7 # [A6 B6 C6 D6 A2 B2 C2 D2] + vpunpckhqdq ymm3,ymm5,ymm7 # [A7 B7 C7 D7 A3 B3 C3 D3] + vpaddw ymm0,ymm0,ymm1 # reduction + vpaddw ymm0,ymm0,ymm2 + vpaddw ymm0,ymm0,ymm3 + vextracti128 xmm1,ymm0,1 # extract high pairs + vpaddw xmm0,xmm0,xmm1 # reduction + vpmaddwd xmm0,xmm0,xmm8 # multiply by offset and reduce + vmovdqu XMMWORD PTR [r9],xmm0 + add r9,4*4 # advance row sum vector by 4 dwords + sub r11,4 # subtract rows remaining + jae .LCopyPackA.ProcessNextRowM4 + +.LCopyPackA.ProcessRemainingRows: + add r11,4 # correct for over-subtract above + jz .LCopyPackA.ExitRoutine + +// +// Process a single row of matrix A in a loop. +// + +.LCopyPackA.ProcessNextRowM1: + vpxor xmm0,xmm0,xmm0 # clear row accumulator + mov rdx,rsi + mov rcx,rdi + add rsi,r10 + lea rdi,[rdi+r12*2] + mov rbx,r8 # reload columns remaining + sub rbx,16 + jb .LCopyPackA.ProcessRemainingColumnsM1 + +.LCopyPackA.ProcessNextColumnLoopM1: + vpmovzxbw ymm4,XMMWORD PTR [rdx] + vmovdqu YMMWORD PTR [rcx],ymm4 + vpaddw ymm0,ymm0,ymm4 # accumulate per row along columns + add rdx,16 # advance matrix A by 16 bytes + add rcx,16*2 # advance matrix D by 16 words + sub rbx,16 # subtract columns remaining + jae .LCopyPackA.ProcessNextColumnLoopM1 + +.LCopyPackA.ProcessRemainingColumnsM1: + add rbx,16 # correct for over-subtract above + jz .LCopyPackA.ReduceRowSumVectorM1 + +// +// Copy the unaligned CountK columns to a zero padded stack buffer. +// + + lea rbp,.LGemmU8U8CopyPackAFrame_PaddedMatrixAData[rsp] + test bl,8 # (CountK & 8) != 0? + jz .LCopyPackA.CopyRemainingCountKLessThan8M1 + mov rax,QWORD PTR [rdx] + mov QWORD PTR [rbp],rax + add rdx,8 + add rbp,8 # advance padded buffer destination + +.LCopyPackA.CopyRemainingCountKLessThan8M1: + test bl,4 # (CountK & 4) != 0? + jz .LCopyPackA.CopyRemainingCountKLessThan4M1 + mov eax,DWORD PTR [rdx] + mov DWORD PTR [rbp],eax + add rdx,4 + add rbp,4 # advance padded buffer destination + +.LCopyPackA.CopyRemainingCountKLessThan4M1: + test bl,2 # (CountK & 2) != 0? 
+ jz .LCopyPackA.CopyRemainingCountKLessThan2M1 + movzx eax,WORD PTR [rdx] + mov WORD PTR [rbp],ax + add rdx,2 + add rbp,2 # advance padded buffer destination + +.LCopyPackA.CopyRemainingCountKLessThan2M1: + test bl,1 # (CountK & 1) != 0? + jz .LCopyPackA.ProcessPaddedMatrixADataM1 + movzx eax,BYTE PTR [rdx] + mov BYTE PTR [rbp],al + +// +// Process the remaining CountK columns using the zero padded stack buffer. +// + +.LCopyPackA.ProcessPaddedMatrixADataM1: + vpmovzxbw ymm4,XMMWORD PTR .LGemmU8U8CopyPackAFrame_PaddedMatrixAData[rsp] + vpmaskmovd YMMWORD PTR [rcx],ymm9,ymm4 + vpaddw ymm0,ymm0,ymm4 # accumulate per row along columns + +// +// Reduce the sum for the single row of output. +// + +.LCopyPackA.ReduceRowSumVectorM1: + vextracti128 xmm1,ymm0,1 # extract high pairs + vpaddw xmm0,xmm0,xmm1 # reduction + vphaddw xmm0,xmm0,xmm0 + vphaddw xmm0,xmm0,xmm0 + vpmaddwd xmm0,xmm0,xmm8 # multiply by offset and reduce + vmovd DWORD PTR [r9],xmm0 + add r9,4 # advance row sum vector by 1 DWORD + dec r11 # decrement rows remaining + jnz .LCopyPackA.ProcessNextRowM1 + +// +// Restore non-volatile registers and return. +// + +.LCopyPackA.ExitRoutine: + vzeroupper + + pop r13 + pop r12 + pop rbx + pop rbp + ret + +/*++ + +Routine Description: + + This routine copies elements from the source matrix to the destination + packed buffer. + +Arguments: + + D (rdi) - Supplies the address of the destination packed buffer. + + B (rsi) - Supplies the address of the source matrix. + + ldb (rdx) - Supplies the number of elements per row of the source matrix. + + CountN (rcx) - Supplies the number of columns of the source matrix to copy. + + CountK (r8) - Supplies the number of rows of the source matrix to copy. + + ColumnSumVector (r9) - Supplies the address of the buffer to receive the sums + of the elements from each of the columns. Each sum has also been + multiplied by the zero point offset. + + offa - Supplies the zero point offset for the other source matrix of the + matrix multiplication. + +Return Value: + + None. + +--*/ + + .globl C_UNDERSCORE(MlasGemmU8U8CopyPackBAvx2) +C_UNDERSCORE(MlasGemmU8U8CopyPackBAvx2): + + push rbp + push rbx + + mov r10,rdx + mov r11,rcx + vpbroadcastw ymm5,WORD PTR .LGemmU8U8CopyPackBFrame_offa[rsp] + +// +// Zero initialize the padded stack buffers. +// + + vpxor xmm0,xmm0,xmm0 + vmovdqu YMMWORD PTR .LGemmU8U8CopyPackBFrame_PaddedMatrixBData[rsp],ymm0 + +// +// Process 16 columns of matrix B in a loop. 
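[Editor's note] The copy routine that follows interleaves pairs of rows so that the two bytes sharing a column land next to each other; the kernel can then zero-extend a (k, k+1) byte pair directly into the 16-bit pair that vpmaddwd consumes. The scalar sketch below packs one block of up to 16 columns; it is a hypothetical helper, and unlike the real routine it uses a CountN-wide stride instead of a fixed, zero-padded 16-column stride.

    #include <cstddef>
    #include <cstdint>

    // D receives an interleaved byte pair per column per row pair;
    // ColumnSum[j] receives -offa * sum_k B[k][j].
    void PackBBlockReference(const uint8_t* B, size_t ldb, size_t CountN,
                             size_t CountK, uint8_t* D, int32_t* ColumnSum,
                             uint8_t offa)
    {
        size_t PairedCountK = (CountK + 1) / 2;
        for (size_t j = 0; j < CountN; j++) {
            ColumnSum[j] = 0;
        }
        for (size_t kk = 0; kk < PairedCountK; kk++) {
            for (size_t j = 0; j < CountN; j++) {
                uint8_t b0 = B[(2 * kk) * ldb + j];
                uint8_t b1 = (2 * kk + 1 < CountK) ? B[(2 * kk + 1) * ldb + j] : 0;
                D[(kk * CountN + j) * 2 + 0] = b0;
                D[(kk * CountN + j) * 2 + 1] = b1;
                ColumnSum[j] += int32_t(b0) + int32_t(b1);
            }
        }
        for (size_t j = 0; j < CountN; j++) {
            ColumnSum[j] *= -int32_t(offa);
        }
    }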
+// + + sub r11,16 + jb .LCopyPackB.ProcessRemainingColumns + +.LCopyPackB.ProcessNextColumnN16: + vpxor xmm0,xmm0,xmm0 # clear column accumulators + vpxor xmm1,xmm1,xmm1 + mov rdx,rsi + add rsi,16 # advance next matrix B by 16 columns + mov rbx,r8 # reload rows remaining + sub rbx,2 + jb .LCopyPackB.ProcessRemainingRowsN16 + +.LCopyPackB.ProcessNextRowLoopN16: + vmovdqu xmm2,XMMWORD PTR [rdx] # load two rows + vmovdqu xmm3,XMMWORD PTR [rdx+r10] + lea rdx,[rdx+r10*2] # advance matrix B by two rows + vpunpcklbw xmm4,xmm2,xmm3 # interleave row data + vpunpckhbw xmm3,xmm2,xmm3 + vmovdqu XMMWORD PTR [rdi],xmm4 # store interleaved rows + vmovdqu XMMWORD PTR [rdi+16],xmm3 + vpmovzxbw ymm4,xmm4 + vpmovzxbw ymm3,xmm3 + add rdi,32 # advance matrix D by 32 bytes + vpaddw ymm0,ymm0,ymm4 # accumulate per column + vpaddw ymm1,ymm1,ymm3 + sub rbx,2 # subtract columns remaining + jae .LCopyPackB.ProcessNextRowLoopN16 + +.LCopyPackB.ProcessRemainingRowsN16: + add rbx,2 # correct for over-subtract above + jz .LCopyPackB.ReduceColumnSumVectorN16 + vpmovzxbw ymm4,XMMWORD PTR [rdx] + vmovdqu YMMWORD PTR [rdi],ymm4 # store interleaved rows + vextracti128 xmm3,ymm4,1 + vpmovzxbw ymm4,xmm4 + vpmovzxbw ymm3,xmm3 + vpaddw ymm0,ymm0,ymm4 # accumulate per column + vpaddw ymm1,ymm1,ymm3 + add rdi,32 # advance matrix D by 32 bytes + +.LCopyPackB.ReduceColumnSumVectorN16: + vpmaddwd ymm0,ymm0,ymm5 # multiply by offset and reduce + vpmaddwd ymm1,ymm1,ymm5 # multiply by offset and reduce + vmovdqu YMMWORD PTR [r9],ymm0 + vmovdqu YMMWORD PTR [r9+32],ymm1 + add r9,64 # advance column sum vector by 16 dwords + sub r11,16 # subtract columns remaining + jae .LCopyPackB.ProcessNextColumnN16 + +.LCopyPackB.ProcessRemainingColumns: + add r11,16 # correct for over-subtract above + jnz .LCopyPackB.ProcessColumnNUnaligned + +// +// Restore non-volatile registers and return. +// + +.LCopyPackB.ExitRoutine: + vzeroupper + + pop rbx + pop rbp + ret + +// +// Process the remaining columns of matrix B. +// + +.LCopyPackB.ProcessColumnNUnaligned: + vpxor xmm0,xmm0,xmm0 # clear column accumulators + vpxor xmm1,xmm1,xmm1 + sub r8,2 + jb .LCopyPackB.ProcessRemainingRowsNUnaligned + +.LCopyPackB.ProcessNextRowLoopNUnaligned: + mov rdx,rsi + lea rbp,.LGemmU8U8CopyPackBFrame_PaddedMatrixBData[rsp] + test r11b,8 # (CountN & 8) != 0? + jz .LCopyPackB.CopyRemainingCountNLessThan8K2 + mov rax,QWORD PTR [rdx] + mov QWORD PTR [rbp],rax + mov rax,QWORD PTR [rdx+r10] + mov QWORD PTR [rbp+16],rax + add rdx,8 # advance matrix B + add rbp,8 # advance padded buffer destination + +.LCopyPackB.CopyRemainingCountNLessThan8K2: + test r11b,4 # (CountN & 4) != 0? + jz .LCopyPackB.CopyRemainingCountNLessThan4K2 + mov eax,DWORD PTR [rdx] + mov DWORD PTR [rbp],eax + mov eax,DWORD PTR [rdx+r10] + mov DWORD PTR [rbp+16],eax + add rdx,4 # advance matrix B + add rbp,4 # advance padded buffer destination + +.LCopyPackB.CopyRemainingCountNLessThan4K2: + test r11b,2 # (CountN & 2) != 0? + jz .LCopyPackB.CopyRemainingCountNLessThan2K2 + movzx eax,WORD PTR [rdx] + mov WORD PTR [rbp],ax + movzx eax,WORD PTR [rdx+r10] + mov WORD PTR [rbp+16],ax + add rdx,2 # advance matrix B + add rbp,2 # advance padded buffer destination + +.LCopyPackB.CopyRemainingCountNLessThan2K2: + test r11b,1 # (CountN & 1) != 0? 
+ jz .LCopyPackB.ProcessPaddedMatrixBDataK2 + movzx eax,BYTE PTR [rdx] + mov BYTE PTR [rbp],al + movzx eax,BYTE PTR [rdx+r10] + mov BYTE PTR [rbp+16],al + +.LCopyPackB.ProcessPaddedMatrixBDataK2: + vmovdqu xmm2,XMMWORD PTR .LGemmU8U8CopyPackBFrame_PaddedMatrixBData[rsp] + vmovdqu xmm3,XMMWORD PTR .LGemmU8U8CopyPackBFrame_PaddedMatrixBData[rsp+16] + vpunpcklbw xmm4,xmm2,xmm3 # interleave row data + vpunpckhbw xmm3,xmm2,xmm3 + vmovdqu XMMWORD PTR [rdi],xmm4 # store interleaved rows + vmovdqu XMMWORD PTR [rdi+16],xmm3 + vpmovzxbw ymm4,xmm4 + vpmovzxbw ymm3,xmm3 + vpaddw ymm0,ymm0,ymm4 # accumulate per column + vpaddw ymm1,ymm1,ymm3 + lea rsi,[rsi+r10*2] # advance next matrix B by two rows + add rdi,32 # advance matrix D by 32 bytes + sub r8,2 # subtract columns remaining + jae .LCopyPackB.ProcessNextRowLoopNUnaligned + +.LCopyPackB.ProcessRemainingRowsNUnaligned: + add r8,2 + jz .LCopyPackB.ReduceColumnSumVectorNUnaligned + mov rdx,rsi + lea rbp,.LGemmU8U8CopyPackBFrame_PaddedMatrixBData[rsp] + test r11b,8 # (CountN & 8) != 0? + jz .LCopyPackB.CopyRemainingCountNLessThan8K1 + mov rax,QWORD PTR [rdx] + mov QWORD PTR [rbp],rax + add rdx,8 # advance matrix B + add rbp,8 # advance padded buffer destination + +.LCopyPackB.CopyRemainingCountNLessThan8K1: + test r11b,4 # (CountN & 4) != 0? + jz .LCopyPackB.CopyRemainingCountNLessThan4K1 + mov eax,DWORD PTR [rdx] + mov DWORD PTR [rbp],eax + add rdx,4 # advance matrix B + add rbp,4 # advance padded buffer destination + +.LCopyPackB.CopyRemainingCountNLessThan4K1: + test r11b,2 # (CountN & 2) != 0? + jz .LCopyPackB.CopyRemainingCountNLessThan2K1 + movzx eax,WORD PTR [rdx] + mov WORD PTR [rbp],ax + add rdx,2 # advance matrix B + add rbp,2 # advance padded buffer destination + +.LCopyPackB.CopyRemainingCountNLessThan2K1: + test r11b,1 # (CountN & 1) != 0? + jz .LCopyPackB.ProcessPaddedMatrixBDataK1 + movzx eax,BYTE PTR [rdx] + mov BYTE PTR [rbp],al + +.LCopyPackB.ProcessPaddedMatrixBDataK1: + vpmovzxbw ymm4,XMMWORD PTR .LGemmU8U8CopyPackBFrame_PaddedMatrixBData[rsp] + vmovdqu YMMWORD PTR [rdi],ymm4 # store interleaved rows + vextracti128 xmm3,ymm4,1 + vpmovzxbw ymm4,xmm4 + vpmovzxbw ymm3,xmm3 + vpaddw ymm0,ymm0,ymm4 # accumulate per column + vpaddw ymm1,ymm1,ymm3 + +.LCopyPackB.ReduceColumnSumVectorNUnaligned: + vpmaddwd ymm0,ymm0,ymm5 # multiply by offset and reduce + vpmaddwd ymm1,ymm1,ymm5 # multiply by offset and reduce + vmovdqu YMMWORD PTR [r9],ymm0 + vmovdqu YMMWORD PTR [r9+32],ymm1 + jmp .LCopyPackB.ExitRoutine + +/*++ + +Macro Description: + + This macro generates code to multiply and accumulator a single row of the + output block. + +Arguments: + + ColumnCount - Supplies the number of columns to produce. + + Vec1Reg - Supplies the high block accumulator register (when ColumnCount + is 16). + + Vec2Reg - Supplies the low block accumulator register. + +Implicit Arguments: + + ymm0 - Supplies the first vector loaded from matrix B. + + ymm1 - Supplies the second vector loaded from matrix B (when ColumnCount + is 16). + + ymm2 - Supplies the broadcast value loaded from matrix A. + +--*/ + + .macro MultiplyAccumulateRow ColumnCount, Vec1Reg, Vec2Reg + +.if \ColumnCount\() == 16 + vpmaddwd ymm3,ymm2,ymm0 + vpaddd \Vec1Reg\(),\Vec1Reg\(),ymm3 + vpmaddwd ymm2,ymm2,ymm1 + vpaddd \Vec2Reg\(),\Vec2Reg\(),ymm2 +.else + vpmaddwd ymm3,ymm2,ymm0 + vpaddd \Vec2Reg\(),\Vec2Reg\(),ymm3 +.endif + + .endm + +/*++ + +Macro Description: + + This macro generates code to multiply and accumulate each row of the output + block. 
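[Editor's note] In intrinsics terms, the 8-column arm of MultiplyAccumulateRow above is one vpmaddwd plus one vpaddd: each 32-bit lane picks up a0*b(k,j) + a1*b(k+1,j). An equivalent helper, assuming AVX2 and shown only for illustration:

    #include <immintrin.h>

    // Accumulator += madd(broadcast A pair, zero-extended B pairs).
    static inline __m256i MultiplyAccumulateRow8(__m256i Accumulator,
                                                 __m256i BPairWords,
                                                 __m256i ABroadcastPair)
    {
        return _mm256_add_epi32(Accumulator,
                                _mm256_madd_epi16(ABroadcastPair, BPairWords));
    }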
+ +Arguments: + + ColumnCount - Supplies the number of columns to produce. + + RowCount - Supplies the number of rows to produce. + + VectorOffset - Supplies the byte offset from matrix B to fetch elements. + + BroadcastOffset - Supplies the byte offset from matrix A to fetch elements. + +Implicit Arguments: + + rdi - Supplies the address into the matrix A data. + + rbx - Supplies the address into the matrix A data plus 3 rows. + + rsi - Supplies the address into the matrix B data. + + r10 - Supplies the length in bytes of a row from matrix A. + + ymm4-ymm15 - Supplies the block accumulators. + +--*/ + + .macro ComputeBlock ColumnCount, RowCount, VectorOffset, BroadcastOffset + + vpmovzxbw ymm0,XMMWORD PTR [rsi+\VectorOffset\()] + EmitIfCountGE \ColumnCount\(), 16, "vpmovzxbw ymm1,XMMWORD PTR [rsi+\VectorOffset\()+16]" + EmitIfCountGE \RowCount\(), 1, "vpbroadcastd ymm2,DWORD PTR [rdi+\BroadcastOffset\()]" + EmitIfCountGE \RowCount\(), 1, "MultiplyAccumulateRow \ColumnCount\(), ymm4, ymm5" + EmitIfCountGE \RowCount\(), 2, "vpbroadcastd ymm2,DWORD PTR [rdi+r10+\BroadcastOffset\()]" + EmitIfCountGE \RowCount\(), 2, "MultiplyAccumulateRow \ColumnCount\(), ymm6, ymm7" + EmitIfCountGE \RowCount\(), 3, "vpbroadcastd ymm2,DWORD PTR [rdi+r10*2+\BroadcastOffset\()]" + EmitIfCountGE \RowCount\(), 3, "MultiplyAccumulateRow \ColumnCount\(), ymm8, ymm9" + EmitIfCountGE \RowCount\(), 4, "vpbroadcastd ymm2,DWORD PTR [rbx+\BroadcastOffset\()]" + EmitIfCountGE \RowCount\(), 4, "MultiplyAccumulateRow \ColumnCount\(), ymm10, ymm11" + EmitIfCountGE \RowCount\(), 5, "vpbroadcastd ymm2,DWORD PTR [rbx+r10+\BroadcastOffset\()]" + EmitIfCountGE \RowCount\(), 5, "MultiplyAccumulateRow \ColumnCount\(), ymm12, ymm13" + EmitIfCountGE \RowCount\(), 6, "vpbroadcastd ymm2,DWORD PTR [rbx+r10*2+\BroadcastOffset\()]" + EmitIfCountGE \RowCount\(), 6, "MultiplyAccumulateRow \ColumnCount\(), ymm14, ymm15" + + .endm + +/*++ + +Macro Description: + + This macro generates code to produce an output block for a set of columns + and rows. + +Arguments: + + ColumnCount - Supplies the number of columns to produce. + + RowCount - Supplies the number of rows to produce. + +Implicit Arguments: + + rax - Supplies the length in bytes of a row from matrix C. + + rdi - Supplies the address into the matrix A data. + + rsi - Supplies the address into the matrix B data. + + rcx - Supplies the number of paired columns from matrix A and the number of + paired rows from matrix B to iterate over. + + r10 - Supplies the length in bytes of a row from matrix A. + + r12 - Supplies the address of the row sum vector. + + r13 - Supplies the address of the column sum vector. + +--*/ + + .macro ProduceOutputBlock ColumnCount, RowCount + +// +// Initialize the accumulators with the sum of the global depth value constant, +// the column sums, and the row sums. 
+// + + vpbroadcastd ymm1,DWORD PTR .LGemmU8U8KernelFrame_DepthValue[rsp] +.if \ColumnCount\() == 16 + vpaddd ymm0,ymm1,YMMWORD PTR [r13] + vpaddd ymm1,ymm1,YMMWORD PTR [r13+32] + add r13,16*4 # advance ColumnSumVector by 16 columns +.else + vpaddd ymm1,ymm1,YMMWORD PTR [r13] +.endif + EmitIfCountGE \RowCount\(), 1, "vpbroadcastd ymm5,DWORD PTR [r12]" + EmitIfCountGE \RowCount\(), 2, "vpbroadcastd ymm7,DWORD PTR [r12+4]" + EmitIfCountGE \RowCount\(), 3, "vpbroadcastd ymm9,DWORD PTR [r12+8]" + EmitIfCountGE \RowCount\(), 4, "vpbroadcastd ymm11,DWORD PTR [r12+12]" + EmitIfCountGE \RowCount\(), 5, "vpbroadcastd ymm13,DWORD PTR [r12+16]" + EmitIfCountGE \RowCount\(), 6, "vpbroadcastd ymm15,DWORD PTR [r12+20]" + EmitIfCount2GE \RowCount\(), 1, \ColumnCount\(), 16, "vpaddd ymm4,ymm5,ymm0" + EmitIfCountGE \RowCount\(), 1, "vpaddd ymm5,ymm5,ymm1" + EmitIfCount2GE \RowCount\(), 2, \ColumnCount\(), 16, "vpaddd ymm6,ymm7,ymm0" + EmitIfCountGE \RowCount\(), 2, "vpaddd ymm7,ymm7,ymm1" + EmitIfCount2GE \RowCount\(), 3, \ColumnCount\(), 16, "vpaddd ymm8,ymm9,ymm0" + EmitIfCountGE \RowCount\(), 3, "vpaddd ymm9,ymm9,ymm1" + EmitIfCount2GE \RowCount\(), 4, \ColumnCount\(), 16, "vpaddd ymm10,ymm11,ymm0" + EmitIfCountGE \RowCount\(), 4, "vpaddd ymm11,ymm11,ymm1" + EmitIfCount2GE \RowCount\(), 5, \ColumnCount\(), 16, "vpaddd ymm12,ymm13,ymm0" + EmitIfCountGE \RowCount\(), 5, "vpaddd ymm13,ymm13,ymm1" + EmitIfCount2GE \RowCount\(), 6, \ColumnCount\(), 16, "vpaddd ymm14,ymm15,ymm0" + EmitIfCountGE \RowCount\(), 6, "vpaddd ymm15,ymm15,ymm1" + +// +// Iterate over PairedCountK elements from matrix A and matrix B. +// +// Unrolling the loop to do two iterations improves performance slightly at the +// cost of larger code size. Balance this by only unrolling for the common case +// of computing 16 columns for an even number of rows. +// + + mov rbp,rcx # reload PairedCountK +.if \RowCount\() > 3 + lea rbx,[r10*2+r10] + add rbx,rdi # compute matrix A plus 3 rows +.endif + +.if (\ColumnCount\() == 16) && ((\RowCount\() & 1) == 0) + sub rbp,2 + jb .LProcessRemainingBlocks.\ColumnCount\().\RowCount\() + +.LComputeBlockLoop.\ColumnCount\().\RowCount\(): + ComputeBlock \ColumnCount\(), \RowCount\(), 0, 0 + ComputeBlock \ColumnCount\(), \RowCount\(), 32, 4 + add rdi,2*4 # advance matrix A by 2 pairs +.if \RowCount\() > 3 + add rbx,2*4 # advance matrix A plus 3 rows by 2 pairs +.endif + add rsi,2*32 # advance matrix B by 64 columns + sub rbp,2 # subtract pairs remaining + jae .LComputeBlockLoop.\ColumnCount\().\RowCount\() + +.LProcessRemainingBlocks.\ColumnCount\().\RowCount\(): + add rbp,2 # correct for over-subtract above + jz .LComputeBlockLoopExit.\ColumnCount\().\RowCount\() + ComputeBlock \ColumnCount\(), \RowCount\(), 0, 0 + add rsi,32 # advance matrix B by 32 columns +.else +.LComputeBlockLoop.\ColumnCount\().\RowCount\(): + ComputeBlock \ColumnCount\(), \RowCount\(), 0, 0 + add rdi,4 # advance matrix A by 1 pair +.if \RowCount\() > 3 + add rbx,4 # advance matrix A plus 3 rows by 1 pair +.endif + add rsi,32 + dec rbp # decrement pairs remaining + jnz .LComputeBlockLoop.\ColumnCount\().\RowCount\() +.endif + +.LComputeBlockLoopExit.\ColumnCount\().\RowCount\(): +.if \RowCount\() > 3 + lea rbx,[rdx+rax*2] # compute matrix C plus 3 rows + add rbx,rax +.endif + + .endm + +/*++ + +Macro Description: + + This macro generates code to compute matrix multiplication for a fixed set + of rows. + +Arguments: + + RowCount - Supplies the number of rows to process. 
+ + Fallthrough - Supplies a non-blank value if the macro may fall through to + the ExitKernel label. + +Implicit Arguments: + + rax - Supplies the length in bytes of a row from matrix C. + + rdi - Supplies the address of matrix A. + + rsi - Supplies the address of matrix B. + + rdx - Supplies the address of matrix C. + + r11 - Supplies the address of matrix A. + + r9 - Supplies the number of columns from matrix B and matrix C to iterate + over. + + rcx - Supplies the number of paired columns from matrix A and the number of + paired rows from matrix B to iterate over. + + r10 - Supplies the length in bytes of a row from matrix A. + + r12 - Supplies the address of the row sum vector. + + r13 - Supplies the address of the column sum vector. + + r14b - Supplies the zero mode flag. + +--*/ + + .macro ProcessCountM RowCount, Fallthrough + + cmp r9,8 + jbe .LProcessRemainingCountN.\RowCount\() + +.LProcessNextColumnLoop16xN.\RowCount\(): + ProduceOutputBlock 16, \RowCount\() + sub r9,16 + jb .LOutputMasked16xNBlock.\RowCount\() + test r14b,r14b # ZeroMode? + jnz .LSkipAccumulateOutput16xNBlock.\RowCount\() + EmitIfCountGE \RowCount\(), 1, "vpaddd ymm4,ymm4,YMMWORD PTR [rdx]" + EmitIfCountGE \RowCount\(), 1, "vpaddd ymm5,ymm5,YMMWORD PTR [rdx+32]" + EmitIfCountGE \RowCount\(), 2, "vpaddd ymm6,ymm6,YMMWORD PTR [rdx+rax]" + EmitIfCountGE \RowCount\(), 2, "vpaddd ymm7,ymm7,YMMWORD PTR [rdx+rax+32]" + EmitIfCountGE \RowCount\(), 3, "vpaddd ymm8,ymm8,YMMWORD PTR [rdx+rax*2]" + EmitIfCountGE \RowCount\(), 3, "vpaddd ymm9,ymm9,YMMWORD PTR [rdx+rax*2+32]" + EmitIfCountGE \RowCount\(), 4, "vpaddd ymm10,ymm10,YMMWORD PTR [rbx]" + EmitIfCountGE \RowCount\(), 4, "vpaddd ymm11,ymm11,YMMWORD PTR [rbx+32]" + EmitIfCountGE \RowCount\(), 5, "vpaddd ymm12,ymm12,YMMWORD PTR [rbx+rax]" + EmitIfCountGE \RowCount\(), 5, "vpaddd ymm13,ymm13,YMMWORD PTR [rbx+rax+32]" + EmitIfCountGE \RowCount\(), 6, "vpaddd ymm14,ymm14,YMMWORD PTR [rbx+rax*2]" + EmitIfCountGE \RowCount\(), 6, "vpaddd ymm15,ymm15,YMMWORD PTR [rbx+rax*2+32]" + +.LSkipAccumulateOutput16xNBlock.\RowCount\(): + EmitIfCountGE \RowCount\(), 1, "vmovdqu YMMWORD PTR [rdx],ymm4" + EmitIfCountGE \RowCount\(), 1, "vmovdqu YMMWORD PTR [rdx+32],ymm5" + EmitIfCountGE \RowCount\(), 2, "vmovdqu YMMWORD PTR [rdx+rax],ymm6" + EmitIfCountGE \RowCount\(), 2, "vmovdqu YMMWORD PTR [rdx+rax+32],ymm7" + EmitIfCountGE \RowCount\(), 3, "vmovdqu YMMWORD PTR [rdx+rax*2],ymm8" + EmitIfCountGE \RowCount\(), 3, "vmovdqu YMMWORD PTR [rdx+rax*2+32],ymm9" + EmitIfCountGE \RowCount\(), 4, "vmovdqu YMMWORD PTR [rbx],ymm10" + EmitIfCountGE \RowCount\(), 4, "vmovdqu YMMWORD PTR [rbx+32],ymm11" + EmitIfCountGE \RowCount\(), 5, "vmovdqu YMMWORD PTR [rbx+rax],ymm12" + EmitIfCountGE \RowCount\(), 5, "vmovdqu YMMWORD PTR [rbx+rax+32],ymm13" + EmitIfCountGE \RowCount\(), 6, "vmovdqu YMMWORD PTR [rbx+rax*2],ymm14" + EmitIfCountGE \RowCount\(), 6, "vmovdqu YMMWORD PTR [rbx+rax*2+32],ymm15" + add rdx,16*4 # advance matrix C by 16 columns + mov rdi,r11 # reload matrix A + cmp r9,8 + ja .LProcessNextColumnLoop16xN.\RowCount\() + test r9,r9 + jz .LExitKernel + +.LProcessRemainingCountN.\RowCount\(): + ProduceOutputBlock 8, \RowCount\() + cmp r9,8 + jb .LOutputMasked8xNBlock.\RowCount\() + test r14b,r14b # ZeroMode? 
+ jnz .LSkipAccumulateOutput8xNBlock.\RowCount\() + EmitIfCountGE \RowCount\(), 1, "vpaddd ymm5,ymm5,YMMWORD PTR [rdx]" + EmitIfCountGE \RowCount\(), 2, "vpaddd ymm7,ymm7,YMMWORD PTR [rdx+rax]" + EmitIfCountGE \RowCount\(), 3, "vpaddd ymm9,ymm9,YMMWORD PTR [rdx+rax*2]" + EmitIfCountGE \RowCount\(), 4, "vpaddd ymm11,ymm11,YMMWORD PTR [rbx]" + EmitIfCountGE \RowCount\(), 5, "vpaddd ymm13,ymm13,YMMWORD PTR [rbx+rax]" + EmitIfCountGE \RowCount\(), 6, "vpaddd ymm15,ymm15,YMMWORD PTR [rbx+rax*2]" + +.LSkipAccumulateOutput8xNBlock.\RowCount\(): + EmitIfCountGE \RowCount\(), 1, "vmovdqu YMMWORD PTR [rdx],ymm5" + EmitIfCountGE \RowCount\(), 2, "vmovdqu YMMWORD PTR [rdx+rax],ymm7" + EmitIfCountGE \RowCount\(), 3, "vmovdqu YMMWORD PTR [rdx+rax*2],ymm9" + EmitIfCountGE \RowCount\(), 4, "vmovdqu YMMWORD PTR [rbx],ymm11" + EmitIfCountGE \RowCount\(), 5, "vmovdqu YMMWORD PTR [rbx+rax],ymm13" + EmitIfCountGE \RowCount\(), 6, "vmovdqu YMMWORD PTR [rbx+rax*2],ymm15" + jmp .LExitKernel + +.LOutputMasked16xNBlock.\RowCount\(): + test r14b,r14b # ZeroMode? + jnz .LSkipAccumulateOutputMasked16xNBlock.\RowCount\() + EmitIfCountGE \RowCount\(), 1, "vpaddd ymm4,ymm4,YMMWORD PTR [rdx]" + EmitIfCountGE \RowCount\(), 2, "vpaddd ymm6,ymm6,YMMWORD PTR [rdx+rax]" + EmitIfCountGE \RowCount\(), 3, "vpaddd ymm8,ymm8,YMMWORD PTR [rdx+rax*2]" + EmitIfCountGE \RowCount\(), 4, "vpaddd ymm10,ymm10,YMMWORD PTR [rbx]" + EmitIfCountGE \RowCount\(), 5, "vpaddd ymm12,ymm12,YMMWORD PTR [rbx+rax]" + EmitIfCountGE \RowCount\(), 6, "vpaddd ymm14,ymm14,YMMWORD PTR [rbx+rax*2]" + +.LSkipAccumulateOutputMasked16xNBlock.\RowCount\(): + EmitIfCountGE \RowCount\(), 1, "vmovdqu YMMWORD PTR [rdx],ymm4" + EmitIfCountGE \RowCount\(), 2, "vmovdqu YMMWORD PTR [rdx+rax],ymm6" + EmitIfCountGE \RowCount\(), 3, "vmovdqu YMMWORD PTR [rdx+rax*2],ymm8" + EmitIfCountGE \RowCount\(), 4, "vmovdqu YMMWORD PTR [rbx],ymm10" + EmitIfCountGE \RowCount\(), 5, "vmovdqu YMMWORD PTR [rbx+rax],ymm12" + EmitIfCountGE \RowCount\(), 6, "vmovdqu YMMWORD PTR [rbx+rax*2],ymm14" + add rdx,8*4 # advance matrix C by 8 columns +.if \RowCount\() > 3 + add rbx,8*4 # advance matrix C plus 3 rows by 8 columns +.endif + add r9,8 # correct for over-subtract above + +.LOutputMasked8xNBlock.\RowCount\(): + mov DWORD PTR .LGemmU8U8KernelFrame_mask[rsp],r9d + vpbroadcastd ymm0,DWORD PTR .LGemmU8U8KernelFrame_mask[rsp] + vpcmpgtd ymm0,ymm0,YMMWORD PTR C_UNDERSCORE(MlasMaskMoveAvx)[rip] + test r14b,r14b # ZeroMode? 
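[Editor's note] The vpcmpgtd just above is the standard MlasMaskMoveAvx trick: comparing a broadcast of the remaining column count against {0,1,...,7} produces an all-ones lane exactly for the lanes still in range, which then drives the masked vpmaskmovd loads and stores. The same idea in intrinsics (illustrative only):

    #include <immintrin.h>

    // Lanes 0..RemainingColumns-1 become 0xFFFFFFFF, the rest become zero.
    static inline __m256i BuildTailMask(int32_t RemainingColumns)
    {
        const __m256i LaneIndices = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7);
        return _mm256_cmpgt_epi32(_mm256_set1_epi32(RemainingColumns), LaneIndices);
    }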
+ jnz .LSkipAccumulateOutputMasked8xNBlock.\RowCount\() + EmitIfCountGE \RowCount\(), 1, "vpmaskmovd ymm4,ymm0,YMMWORD PTR [rdx]" + EmitIfCountGE \RowCount\(), 2, "vpmaskmovd ymm6,ymm0,YMMWORD PTR [rdx+rax]" + EmitIfCountGE \RowCount\(), 3, "vpmaskmovd ymm8,ymm0,YMMWORD PTR [rdx+rax*2]" + EmitIfCountGE \RowCount\(), 4, "vpmaskmovd ymm10,ymm0,YMMWORD PTR [rbx]" + EmitIfCountGE \RowCount\(), 5, "vpmaskmovd ymm12,ymm0,YMMWORD PTR [rbx+rax]" + EmitIfCountGE \RowCount\(), 6, "vpmaskmovd ymm14,ymm0,YMMWORD PTR [rbx+rax*2]" + EmitIfCountGE \RowCount\(), 1, "vpaddd ymm5,ymm5,ymm4" + EmitIfCountGE \RowCount\(), 2, "vpaddd ymm7,ymm7,ymm6" + EmitIfCountGE \RowCount\(), 3, "vpaddd ymm9,ymm9,ymm8" + EmitIfCountGE \RowCount\(), 4, "vpaddd ymm11,ymm11,ymm10" + EmitIfCountGE \RowCount\(), 5, "vpaddd ymm13,ymm13,ymm12" + EmitIfCountGE \RowCount\(), 6, "vpaddd ymm15,ymm15,ymm14" + +.LSkipAccumulateOutputMasked8xNBlock.\RowCount\(): + EmitIfCountGE \RowCount\(), 1, "vpmaskmovd YMMWORD PTR [rdx],ymm0,ymm5" + EmitIfCountGE \RowCount\(), 2, "vpmaskmovd YMMWORD PTR [rdx+rax],ymm0,ymm7" + EmitIfCountGE \RowCount\(), 3, "vpmaskmovd YMMWORD PTR [rdx+rax*2],ymm0,ymm9" + EmitIfCountGE \RowCount\(), 4, "vpmaskmovd YMMWORD PTR [rbx],ymm0,ymm11" + EmitIfCountGE \RowCount\(), 5, "vpmaskmovd YMMWORD PTR [rbx+rax],ymm0,ymm13" + EmitIfCountGE \RowCount\(), 6, "vpmaskmovd YMMWORD PTR [rbx+rax*2],ymm0,ymm15" +.ifb \Fallthrough\() + jmp .LExitKernel +.endif + + .endm + +/*++ + +Routine Description: + + This routine is an inner kernel to compute matrix multiplication for a + set of rows. + +Arguments: + + A (rdi) - Supplies the address of matrix A. The matrix data has been packed + using MlasGemmU8U8CopyPackAAvx2. + + B (rsi) - Supplies the address of matrix B. The matrix data has been packed + using MlasGemmU8U8CopyPackBAvx2. + + C (rdx) - Supplies the address of matrix C. + + PairedCountK (rcx) - Supplies the number of paired columns from matrix A and + the number of paired rows from matrix B to iterate over. + + CountM (r8) - Supplies the maximum number of rows that can be processed for + matrix A and matrix C. The actual number of rows handled for this + invocation depends on the kernel implementation. + + CountN (r9) - Supplies the number of columns from matrix B and matrix C to + iterate over. + + ldc - Supplies the first dimension of matrix C. + + RowSumVector - Supplies the sum of each row from matrix A multiplied by the + zero point offset of matrix B. These values are accumulated into every + row of matrix C. + + ColumnSumVector - Supplies the sum of each column from matrix B multiplied + by the zero point offset of matrix A. These values are accumulated into + every column of matrix C. + + DepthValue - Supplies the value CountK multiplied by the zero point offset + of matrixA multplied by the zero point offset of matrix B. This value is + accumulated into every element of matrix C. + + ZeroMode - Supplies true if the output matrix must be zero initialized, + else false if the output matrix is accumulated into. + +Return Value: + + Returns the number of rows handled. 
+ +--*/ + + .globl C_UNDERSCORE(MlasGemmU8U8KernelAvx2) +C_UNDERSCORE(MlasGemmU8U8KernelAvx2): + + push rbp + push rbx + push r12 + push r13 + push r14 + + mov rax,.LGemmU8U8KernelFrame_ldc[rsp] + shl rax,2 # convert ldc to bytes + lea r10,[rcx*4] + mov r11,rdi + mov r12,.LGemmU8U8KernelFrame_RowSumVector[rsp] + mov r13,.LGemmU8U8KernelFrame_ColumnSumVector[rsp] + movzx r14,BYTE PTR .LGemmU8U8KernelFrame_ZeroMode[rsp] + +// +// Process CountM rows of the matrices. +// + + cmp r8,5 + ja .LProcessCountM6 + je .LProcessCountM5 + cmp r8,3 + ja .LProcessCountM4 + je .LProcessCountM3 + cmp r8,1 + je .LProcessCountM1 + +.LProcessCountM2: + ProcessCountM 2 + +.LProcessCountM4: + ProcessCountM 4 + +.LProcessCountM6: + mov r8d,6 # return 6 rows handled + ProcessCountM 6, Fallthrough + +// +// Restore non-volatile registers and return. +// + +.LExitKernel: + mov eax,r8d + vzeroupper + + pop r14 + pop r13 + pop r12 + pop rbx + pop rbp + ret + +.LProcessCountM1: + ProcessCountM 1 + +.LProcessCountM3: + ProcessCountM 3 + +.LProcessCountM5: + ProcessCountM 5 + + .end diff --git a/onnxruntime/core/mlas/lib/x86_64/QgemmU8U8KernelAvx512BW.S b/onnxruntime/core/mlas/lib/x86_64/QgemmU8U8KernelAvx512BW.S new file mode 100644 index 0000000000000..bacb29a9a138c --- /dev/null +++ b/onnxruntime/core/mlas/lib/x86_64/QgemmU8U8KernelAvx512BW.S @@ -0,0 +1,120 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + QgemmU8U8KernelAvx512BW.s + +Abstract: + + This module implements the kernels for the quantized integer matrix/matrix + multiply operation (QGEMM). + + This implementation uses AVX512BW instructions. + +--*/ + +#include "asmmacro.h" +#include "QgemmU8U8KernelAvx512Common.h" + + .intel_syntax noprefix + + .text + +/*++ + +Macro Description: + + This macro generates code to multiply and accumulator a single row of the + output block. + +Arguments: + + ColumnCount - Supplies the number of columns to produce. + + Vec1Reg - Supplies the high block accumulator register (when ColumnCount + is 32). + + Vec2Reg - Supplies the low block accumulator register. + +Implicit Arguments: + + zmm28 - Supplies the first vector loaded from matrix B. + + zmm29 - Supplies the second vector loaded from matrix B (when ColumnCount + is 32). + + zmm30 - Supplies the broadcast value loaded from matrix A. + +--*/ + + .macro MultiplyAccumulateRow ColumnCount, Vec1Reg, Vec2Reg + +.if \ColumnCount\() == 32 + vpmaddwd zmm31,zmm30,zmm28 + vpaddd \Vec1Reg\(),\Vec1Reg\(),zmm31 + vpmaddwd zmm30,zmm30,zmm29 + vpaddd \Vec2Reg\(),\Vec2Reg\(),zmm30 +.else + vpmaddwd zmm31,zmm30,zmm28 + vpaddd \Vec2Reg\(),\Vec2Reg\(),zmm31 +.endif + + .endm + +/*++ + +Macro Description: + + This macro generates code to multiply and accumulate each row of the output + block. + +Arguments: + + ColumnCount - Supplies the number of columns to produce. + + RowCount - Supplies the number of rows to produce. + +Implicit Arguments: + + rdi - Supplies the address into the matrix A data. + + rbx - Supplies the address into the matrix A data plus 3 rows. + + rsi - Supplies the address into the matrix B data. + + r10 - Supplies the length in bytes of a row from matrix A. + + zmm16-zmm27 - Supplies the block accumulators. 
+ +--*/ + + .macro ComputeBlock ColumnCount, RowCount + + vpmovzxbw zmm28,YMMWORD PTR [rsi] + EmitIfCountGE \ColumnCount\(), 32, "vpmovzxbw zmm29,YMMWORD PTR [rsi+r10*8]" + EmitIfCountGE \RowCount\(), 1, "vpbroadcastd zmm30,DWORD PTR [rdi]" + EmitIfCountGE \RowCount\(), 1, "MultiplyAccumulateRow \ColumnCount\(), zmm16, zmm17" + EmitIfCountGE \RowCount\(), 2, "vpbroadcastd zmm30,DWORD PTR [rdi+r10]" + EmitIfCountGE \RowCount\(), 2, "MultiplyAccumulateRow \ColumnCount\(), zmm18, zmm19" + EmitIfCountGE \RowCount\(), 3, "vpbroadcastd zmm30,DWORD PTR [rdi+r10*2]" + EmitIfCountGE \RowCount\(), 3, "MultiplyAccumulateRow \ColumnCount\(), zmm20, zmm21" + EmitIfCountGE \RowCount\(), 4, "vpbroadcastd zmm30,DWORD PTR [rbx]" + EmitIfCountGE \RowCount\(), 4, "MultiplyAccumulateRow \ColumnCount\(), zmm22, zmm23" + EmitIfCountGE \RowCount\(), 5, "vpbroadcastd zmm30,DWORD PTR [rbx+r10]" + EmitIfCountGE \RowCount\(), 5, "MultiplyAccumulateRow \ColumnCount\(), zmm24, zmm25" + EmitIfCountGE \RowCount\(), 6, "vpbroadcastd zmm30,DWORD PTR [rbx+r10*2]" + EmitIfCountGE \RowCount\(), 6, "MultiplyAccumulateRow \ColumnCount\(), zmm26, zmm27" + + .endm + +// +// Generate the GEMM kernel. +// + +GemmU8U8KernelAvx512Function Avx512BW + + .end diff --git a/onnxruntime/core/mlas/lib/x86_64/QgemmU8U8KernelAvx512Common.h b/onnxruntime/core/mlas/lib/x86_64/QgemmU8U8KernelAvx512Common.h new file mode 100644 index 0000000000000..3abd87b7ce986 --- /dev/null +++ b/onnxruntime/core/mlas/lib/x86_64/QgemmU8U8KernelAvx512Common.h @@ -0,0 +1,361 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + QgemmU8U8KernelAvx512Common.h + +Abstract: + + This module contains common kernel macros and structures for the quantized + integer matrix/matrix multiply operation (QGEMM) for the AVX512BW and + AVX512VNNI kernels. + +--*/ + +// +// Stack frame layout for the U8U8 kernel. +// + + .equ .LGemmU8U8KernelFrame_SavedR14, 0 + .equ .LGemmU8U8KernelFrame_SavedR13, 8 + .equ .LGemmU8U8KernelFrame_SavedR12, 16 + .equ .LGemmU8U8KernelFrame_SavedRbx, 24 + .equ .LGemmU8U8KernelFrame_SavedRbp, 32 + .equ .LGemmU8U8KernelFrame_ReturnAddress, 40 + .equ .LGemmU8U8KernelFrame_ldc, 48 + .equ .LGemmU8U8KernelFrame_RowSumVector, 56 + .equ .LGemmU8U8KernelFrame_ColumnSumVector, 64 + .equ .LGemmU8U8KernelFrame_DepthValue, 72 + .equ .LGemmU8U8KernelFrame_ZeroMode, 80 + +/*++ + +Macro Description: + + This macro generates code to produce an output block for a set of columns + and rows. + +Arguments: + + ColumnCount - Supplies the number of columns to produce. + + RowCount - Supplies the number of rows to produce. + +Implicit Arguments: + + rax - Supplies the length in bytes of a row from matrix C. + + rdi - Supplies the address into the matrix A data. + + rsi - Supplies the address into the matrix B data. + + rcx - Supplies the number of paired columns from matrix A and the number of + paired rows from matrix B to iterate over. + + r10 - Supplies the length in bytes of a row from matrix A. + + r12 - Supplies the address of the row sum vector. + + r13 - Supplies the address of the column sum vector. + +--*/ + + .macro ProduceOutputBlock ColumnCount, RowCount + +// +// Initialize the accumulators with the sum of the global depth value constant, +// the column sums, and the row sums. 
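[Editor's note] The AVX-512 path seeds its accumulators the same way as the AVX2 code, but leans on embedded broadcasts ({1to16}) so the per-row sum never needs a register of its own. Per 16-column accumulator this amounts to the following, shown with AVX-512F intrinsics for illustration only:

    #include <immintrin.h>
    #include <cstdint>

    // One 16-column accumulator starts at DepthValue + RowSum + ColumnSum[j..j+15].
    static inline __m512i SeedAccumulator16(int32_t DepthValue, int32_t RowSum,
                                            const int32_t* ColumnSum)
    {
        __m512i Seed = _mm512_set1_epi32(DepthValue + RowSum);
        return _mm512_add_epi32(Seed, _mm512_loadu_si512(ColumnSum));
    }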
+// + + vpbroadcastd zmm31,DWORD PTR .LGemmU8U8KernelFrame_DepthValue[rsp] +.if \ColumnCount\() == 32 + vpaddd zmm30,zmm31,ZMMWORD PTR [r13] + vpaddd zmm31,zmm31,ZMMWORD PTR [r13+64] + add r13,32*4 # advance ColumnSumVector by 32 columns +.else + vpaddd zmm31,zmm31,ZMMWORD PTR [r13] +.endif + EmitIfCount2GE \RowCount\(), 1, \ColumnCount\(), 32, "vpaddd zmm16,zmm30,DWORD PTR [r12]{1to16}" + EmitIfCountGE \RowCount\(), 1, "vpaddd zmm17,zmm31,DWORD PTR [r12]{1to16}" + EmitIfCount2GE \RowCount\(), 2, \ColumnCount\(), 32, "vpaddd zmm18,zmm30,DWORD PTR [r12+4]{1to16}" + EmitIfCountGE \RowCount\(), 2, "vpaddd zmm19,zmm31,DWORD PTR [r12+4]{1to16}" + EmitIfCount2GE \RowCount\(), 3, \ColumnCount\(), 32, "vpaddd zmm20,zmm30,DWORD PTR [r12+8]{1to16}" + EmitIfCountGE \RowCount\(), 3, "vpaddd zmm21,zmm31,DWORD PTR [r12+8]{1to16}" + EmitIfCount2GE \RowCount\(), 4, \ColumnCount\(), 32, "vpaddd zmm22,zmm30,DWORD PTR [r12+12]{1to16}" + EmitIfCountGE \RowCount\(), 4, "vpaddd zmm23,zmm31,DWORD PTR [r12+12]{1to16}" + EmitIfCount2GE \RowCount\(), 5, \ColumnCount\(), 32, "vpaddd zmm24,zmm30,DWORD PTR [r12+16]{1to16}" + EmitIfCountGE \RowCount\(), 5, "vpaddd zmm25,zmm31,DWORD PTR [r12+16]{1to16}" + EmitIfCount2GE \RowCount\(), 6, \ColumnCount\(), 32, "vpaddd zmm26,zmm30,DWORD PTR [r12+20]{1to16}" + EmitIfCountGE \RowCount\(), 6, "vpaddd zmm27,zmm31,DWORD PTR [r12+20]{1to16}" + +// +// Iterate over PairedCountK elements from matrix A and matrix B. +// + + mov rbp,rcx # reload PairedCountK +.if \RowCount\() > 3 + lea rbx,[r10*2+r10] + add rbx,rdi # compute matrix A plus 3 rows +.endif + +.LComputeBlockLoop.\ColumnCount\().\RowCount\(): + ComputeBlock \ColumnCount\(), \RowCount\() + add rdi,4 # advance matrix A by 1 pair +.if \RowCount\() > 3 + add rbx,4 # advance matrix A plus 3 rows by 1 pair +.endif + add rsi,32 + dec rbp # decrement pairs remaining + jnz .LComputeBlockLoop.\ColumnCount\().\RowCount\() + +.if \RowCount\() > 3 + lea rbx,[rdx+rax*2] # compute matrix C plus 3 rows + add rbx,rax +.endif + + .endm + +/*++ + +Macro Description: + + This macro generates code to compute matrix multiplication for a fixed set + of rows. + +Arguments: + + RowCount - Supplies the number of rows to process. + +Implicit Arguments: + + rax - Supplies the length in bytes of a row from matrix C. + + rdi - Supplies the address of matrix A. + + rsi - Supplies the address of matrix B. + + rdx - Supplies the address of matrix C. + + r11 - Supplies the address of matrix A. + + r9 - Supplies the number of columns from matrix B and matrix C to iterate + over. + + rcx - Supplies the number of paired columns from matrix A and the number of + paired rows from matrix B to iterate over. + + r10 - Supplies the length in bytes of a row from matrix A. + + r12 - Supplies the address of the row sum vector. + + r13 - Supplies the address of the column sum vector. + + r14b - Supplies the zero mode flag. + +--*/ + + .macro ProcessCountM RowCount + + cmp r9,16 + jbe .LProcessRemainingCountN.\RowCount\() + +.LProcessNextColumnLoop32xN.\RowCount\(): + ProduceOutputBlock 32, \RowCount\() + lea rsi,[rsi+r10*8] # advance matrix B by 8*PairedCountK + test r14b,r14b # ZeroMode? 
+ jnz .LSkipAccumulateOutput32xNBlock.\RowCount\() + EmitIfCountGE \RowCount\(), 1, "vpaddd zmm16,zmm16,ZMMWORD PTR [rdx]" + EmitIfCountGE \RowCount\(), 2, "vpaddd zmm18,zmm18,ZMMWORD PTR [rdx+rax]" + EmitIfCountGE \RowCount\(), 3, "vpaddd zmm20,zmm20,ZMMWORD PTR [rdx+rax*2]" + EmitIfCountGE \RowCount\(), 4, "vpaddd zmm22,zmm22,ZMMWORD PTR [rbx]" + EmitIfCountGE \RowCount\(), 5, "vpaddd zmm24,zmm24,ZMMWORD PTR [rbx+rax]" + EmitIfCountGE \RowCount\(), 6, "vpaddd zmm26,zmm26,ZMMWORD PTR [rbx+rax*2]" + +.LSkipAccumulateOutput32xNBlock.\RowCount\(): + EmitIfCountGE \RowCount\(), 1, "vmovdqu32 ZMMWORD PTR [rdx],zmm16" + EmitIfCountGE \RowCount\(), 2, "vmovdqu32 ZMMWORD PTR [rdx+rax],zmm18" + EmitIfCountGE \RowCount\(), 3, "vmovdqu32 ZMMWORD PTR [rdx+rax*2],zmm20" + EmitIfCountGE \RowCount\(), 4, "vmovdqu32 ZMMWORD PTR [rbx],zmm22" + EmitIfCountGE \RowCount\(), 5, "vmovdqu32 ZMMWORD PTR [rbx+rax],zmm24" + EmitIfCountGE \RowCount\(), 6, "vmovdqu32 ZMMWORD PTR [rbx+rax*2],zmm26" + add rdx,16*4 # advance matrix C by 16 columns +.if \RowCount\() > 3 + add rbx,16*4 # advance matrix C plus 3 rows by 16 columns +.endif + sub r9,16 + +.LOutput16xNBlock.\RowCount\(): + sub r9,16 + jae .LOutput16xNBlockWithMask.\RowCount\() + lea rcx,[r9+16] # correct for over-subtract above + mov ebp,1 + shl ebp,cl + dec ebp + kmovw k1,ebp # update mask for remaining columns + xor r9,r9 # no more columns remaining + +.LOutput16xNBlockWithMask.\RowCount\(): + test r14b,r14b # ZeroMode? + jnz .LSkipAccumulateOutput16xNBlockWithMask.\RowCount\() + EmitIfCountGE \RowCount\(), 1, "vpaddd zmm17{k1},zmm17,ZMMWORD PTR [rdx]" + EmitIfCountGE \RowCount\(), 2, "vpaddd zmm19{k1},zmm19,ZMMWORD PTR [rdx+rax]" + EmitIfCountGE \RowCount\(), 3, "vpaddd zmm21{k1},zmm21,ZMMWORD PTR [rdx+rax*2]" + EmitIfCountGE \RowCount\(), 4, "vpaddd zmm23{k1},zmm23,ZMMWORD PTR [rbx]" + EmitIfCountGE \RowCount\(), 5, "vpaddd zmm25{k1},zmm25,ZMMWORD PTR [rbx+rax]" + EmitIfCountGE \RowCount\(), 6, "vpaddd zmm27{k1},zmm27,ZMMWORD PTR [rbx+rax*2]" + +.LSkipAccumulateOutput16xNBlockWithMask.\RowCount\(): + EmitIfCountGE \RowCount\(), 1, "vmovdqu32 ZMMWORD PTR [rdx]{k1},zmm17" + EmitIfCountGE \RowCount\(), 2, "vmovdqu32 ZMMWORD PTR [rdx+rax]{k1},zmm19" + EmitIfCountGE \RowCount\(), 3, "vmovdqu32 ZMMWORD PTR [rdx+rax*2]{k1},zmm21" + EmitIfCountGE \RowCount\(), 4, "vmovdqu32 ZMMWORD PTR [rbx]{k1},zmm23" + EmitIfCountGE \RowCount\(), 5, "vmovdqu32 ZMMWORD PTR [rbx+rax]{k1},zmm25" + EmitIfCountGE \RowCount\(), 6, "vmovdqu32 ZMMWORD PTR [rbx+rax*2]{k1},zmm27" + add rdx,16*4 # advance matrix C by 16 columns + mov rdi,r11 # reload matrix A + cmp r9,16 + ja .LProcessNextColumnLoop32xN.\RowCount\() + test r9,r9 + jz .LExitKernel + +.LProcessRemainingCountN.\RowCount\(): + ProduceOutputBlock 16, \RowCount\() + jmp .LOutput16xNBlock.\RowCount\() + + .endm + +/*++ + +Macro Description: + + This macro generates the common AVX512 code for the inner kernel to compute + matrix multiplication. + +Arguments: + + Isa - Supplies the instruction set architecture string for function tags. + +--*/ + + .macro GemmU8U8KernelAvx512Function Isa + +/*++ + +Routine Description: + + This routine is an inner kernel to compute matrix multiplication for a + set of rows. + +Arguments: + + A (rdi) - Supplies the address of matrix A. The matrix data has been packed + using MlasGemmU8U8CopyPackAAvx2. + + B (rsi) - Supplies the address of matrix B. The matrix data has been packed + using MlasGemmU8U8CopyPackBAvx2. + + C (rdx) - Supplies the address of matrix C. 
+ + PairedCountK (rcx) - Supplies the number of paired columns from matrix A and + the number of paired rows from matrix B to iterate over. + + CountM (r8) - Supplies the maximum number of rows that can be processed for + matrix A and matrix C. The actual number of rows handled for this + invocation depends on the kernel implementation. + + CountN (r9) - Supplies the number of columns from matrix B and matrix C to + iterate over. + + ldc - Supplies the first dimension of matrix C. + + RowSumVector - Supplies the sum of each row from matrix A multiplied by the + zero point offset of matrix B. These values are accumulated into every + row of matrix C. + + ColumnSumVector - Supplies the sum of each column from matrix B multiplied + by the zero point offset of matrix A. These values are accumulated into + every column of matrix C. + + DepthValue - Supplies the value CountK multiplied by the zero point offset + of matrix A multiplied by the zero point offset of matrix B. This value is + accumulated into every element of matrix C. + + ZeroMode - Supplies true if the output matrix must be zero initialized, + else false if the output matrix is accumulated into. + +Return Value: + + Returns the number of rows handled. + +--*/ + + .globl C_UNDERSCORE(MlasGemmU8U8Kernel\Isa\()) +C_UNDERSCORE(MlasGemmU8U8Kernel\Isa\()): + + push rbp + push rbx + push r12 + push r13 + push r14 + + mov rax,.LGemmU8U8KernelFrame_ldc[rsp] + shl rax,2 # convert ldc to bytes + lea r10,[rcx*4] + mov r11,rdi + mov r12,.LGemmU8U8KernelFrame_RowSumVector[rsp] + mov r13,.LGemmU8U8KernelFrame_ColumnSumVector[rsp] + movzx r14,BYTE PTR .LGemmU8U8KernelFrame_ZeroMode[rsp] + mov ebp,-1 + kmovw k1,ebp # update mask to write all columns + +// +// Process CountM rows of the matrices. +// + + cmp r8,5 + ja .LProcessCountM6 + je .LProcessCountM5 + cmp r8,3 + ja .LProcessCountM4 + je .LProcessCountM3 + cmp r8,1 + je .LProcessCountM1 + +.LProcessCountM2: + ProcessCountM 2 + +.LProcessCountM4: + ProcessCountM 4 + +.LProcessCountM6: + mov r8d,6 # return 6 rows handled + ProcessCountM 6 + +// +// Restore non-volatile registers and return. +// + +.LExitKernel: + mov eax,r8d + + pop r14 + pop r13 + pop r12 + pop rbx + pop rbp + ret + +.LProcessCountM1: + ProcessCountM 1 + +.LProcessCountM3: + ProcessCountM 3 + +.LProcessCountM5: + ProcessCountM 5 + + .endm diff --git a/onnxruntime/core/mlas/lib/x86_64/QgemmU8U8KernelAvx512Vnni.S b/onnxruntime/core/mlas/lib/x86_64/QgemmU8U8KernelAvx512Vnni.S new file mode 100644 index 0000000000000..76a85427d5689 --- /dev/null +++ b/onnxruntime/core/mlas/lib/x86_64/QgemmU8U8KernelAvx512Vnni.S @@ -0,0 +1,95 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + QgemmU8U8KernelAvx512Vnni.s + +Abstract: + + This module implements the kernels for the quantized integer matrix/matrix + multiply operation (QGEMM). + + This implementation uses AVX512VNNI instructions. + +--*/ + +#include "asmmacro.h" +#include "QgemmU8U8KernelAvx512Common.h" +#include "AssembleAvx512Vnni.h" + + .intel_syntax noprefix + + .text + +/*++ + +Macro Description: + + This macro generates code to multiply and accumulate each row of the output + block. + +Arguments: + + ColumnCount - Supplies the number of columns to produce. + + RowCount - Supplies the number of rows to produce. + +Implicit Arguments: + + rdi - Supplies the address into the matrix A data. + + rbx - Supplies the address into the matrix A data plus 3 rows. + + rsi - Supplies the address into the matrix B data.
+ + r10 - Supplies the length in bytes of a row from matrix A. + + zmm16-zmm27 - Supplies the block accumulators. + +--*/ + + .macro ComputeBlock ColumnCount, RowCount + + vpmovzxbw zmm28,YMMWORD PTR [rsi] +.if \ColumnCount\() == 32 + vpmovzxbw zmm29,YMMWORD PTR [rsi+r10*8] + EmitIfCountGE \RowCount\(), 1, "vpbroadcastd zmm30,DWORD PTR [rdi]" + EmitIfCountGE \RowCount\(), 1, "VpdpwssdZmmZmmZmm zmm16,zmm28,zmm30" + EmitIfCountGE \RowCount\(), 1, "VpdpwssdZmmZmmZmm zmm17,zmm29,zmm30" + EmitIfCountGE \RowCount\(), 2, "vpbroadcastd zmm30,DWORD PTR [rdi+r10]" + EmitIfCountGE \RowCount\(), 2, "VpdpwssdZmmZmmZmm zmm18,zmm28,zmm30" + EmitIfCountGE \RowCount\(), 2, "VpdpwssdZmmZmmZmm zmm19,zmm29,zmm30" + EmitIfCountGE \RowCount\(), 3, "vpbroadcastd zmm30,DWORD PTR [rdi+r10*2]" + EmitIfCountGE \RowCount\(), 3, "VpdpwssdZmmZmmZmm zmm20,zmm28,zmm30" + EmitIfCountGE \RowCount\(), 3, "VpdpwssdZmmZmmZmm zmm21,zmm29,zmm30" + EmitIfCountGE \RowCount\(), 4, "vpbroadcastd zmm30,DWORD PTR [rbx]" + EmitIfCountGE \RowCount\(), 4, "VpdpwssdZmmZmmZmm zmm22,zmm28,zmm30" + EmitIfCountGE \RowCount\(), 4, "VpdpwssdZmmZmmZmm zmm23,zmm29,zmm30" + EmitIfCountGE \RowCount\(), 5, "vpbroadcastd zmm30,DWORD PTR [rbx+r10]" + EmitIfCountGE \RowCount\(), 5, "VpdpwssdZmmZmmZmm zmm24,zmm28,zmm30" + EmitIfCountGE \RowCount\(), 5, "VpdpwssdZmmZmmZmm zmm25,zmm29,zmm30" + EmitIfCountGE \RowCount\(), 6, "vpbroadcastd zmm30,DWORD PTR [rbx+r10*2]" + EmitIfCountGE \RowCount\(), 6, "VpdpwssdZmmZmmZmm zmm26,zmm28,zmm30" + EmitIfCountGE \RowCount\(), 6, "VpdpwssdZmmZmmZmm zmm27,zmm29,zmm30" +.else + EmitIfCountGE \RowCount\(), 1, "VpdpwssdZmmZmmBroadcast zmm17,zmm28,rdi" + EmitIfCountGE \RowCount\(), 2, "VpdpwssdZmmZmmBroadcast zmm19,zmm28,rdi,r10,1" + EmitIfCountGE \RowCount\(), 3, "VpdpwssdZmmZmmBroadcast zmm21,zmm28,rdi,r10,2" + EmitIfCountGE \RowCount\(), 4, "VpdpwssdZmmZmmBroadcast zmm23,zmm28,rbx" + EmitIfCountGE \RowCount\(), 5, "VpdpwssdZmmZmmBroadcast zmm25,zmm28,rbx,r10,1" + EmitIfCountGE \RowCount\(), 6, "VpdpwssdZmmZmmBroadcast zmm27,zmm28,rbx,r10,2" +.endif + + .endm + +// +// Generate the GEMM kernel. 
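For readers of the two U8U8 kernels above: the reason ProduceOutputBlock can seed the accumulators once from DepthValue, ColumnSumVector, and RowSumVector and then let ComputeBlock accumulate only raw products is the usual zero-point expansion of a quantized dot product. The stand-alone C++ check below is a minimal illustration of that identity only, not the MLAS code; the matrix sizes, data values, and zero points are made up, and the sign convention of the precomputed vectors is assumed to be folded in by the kernel's caller.

// Illustrative scalar check (not MLAS): sum_k (A[i][k]-za)*(B[k][j]-zb)
//   == sum_k A[i][k]*B[k][j]          (what ComputeBlock accumulates)
//    + K*za*zb                        (DepthValue)
//    - zb*sum_k A[i][k]               (RowSumVector term for row i)
//    - za*sum_k B[k][j]               (ColumnSumVector term for column j)
// The real kernel's caller pre-folds signs and zero points into the vectors it passes in.
#include <cassert>
#include <cstdint>
#include <vector>

int main() {
  const int M = 2, N = 3, K = 4;
  const int32_t za = 121, zb = 7;  // assumed zero points of A and B
  std::vector<uint8_t> A(M * K), B(K * N);
  for (int i = 0; i < M * K; ++i) A[i] = static_cast<uint8_t>(3 * i + 1);
  for (int i = 0; i < K * N; ++i) B[i] = static_cast<uint8_t>(5 * i + 2);

  for (int i = 0; i < M; ++i) {
    for (int j = 0; j < N; ++j) {
      int32_t direct = 0, dot = 0, row_sum = 0, col_sum = 0;
      for (int k = 0; k < K; ++k) {
        const int32_t a = A[i * K + k], b = B[k * N + j];
        direct += (a - za) * (b - zb);  // definition of the quantized product
        dot += a * b;                   // raw product accumulated by ComputeBlock
        row_sum += a;
        col_sum += b;
      }
      const int32_t depth_value = K * za * zb;
      assert(direct == dot + depth_value - zb * row_sum - za * col_sum);
    }
  }
  return 0;
}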
+// + +GemmU8U8KernelAvx512Function Avx512Vnni + + .end diff --git a/onnxruntime/core/mlas/lib/x86_64/SconvKernelAvx.S b/onnxruntime/core/mlas/lib/x86_64/SconvKernelAvx.S index 617119763a2b5..2163708dcb352 100644 --- a/onnxruntime/core/mlas/lib/x86_64/SconvKernelAvx.S +++ b/onnxruntime/core/mlas/lib/x86_64/SconvKernelAvx.S @@ -257,9 +257,15 @@ Arguments: .macro PostProcessBlock FilterCount, OutputCount .globl MlasConvPostProcessFloatAvxFilter\FilterCount\()Output\OutputCount\() +#if !defined(__APPLE__) + .hidden MlasConvPostProcessFloatAvxFilter\FilterCount\()Output\OutputCount\() +#endif MlasConvPostProcessFloatAvxFilter\FilterCount\()Output\OutputCount\(): .globl MlasConvPostProcessFloatFma3Filter\FilterCount\()Output\OutputCount\() +#if !defined(__APPLE__) + .hidden MlasConvPostProcessFloatFma3Filter\FilterCount\()Output\OutputCount\() +#endif MlasConvPostProcessFloatFma3Filter\FilterCount\()Output\OutputCount\(): .if \FilterCount\() > 2 diff --git a/onnxruntime/core/mlas/lib/x86_64/SconvKernelAvx512F.S b/onnxruntime/core/mlas/lib/x86_64/SconvKernelAvx512F.S index 873cb4dbf9431..55d2aa613f212 100644 --- a/onnxruntime/core/mlas/lib/x86_64/SconvKernelAvx512F.S +++ b/onnxruntime/core/mlas/lib/x86_64/SconvKernelAvx512F.S @@ -361,6 +361,9 @@ Arguments: .macro PostProcessBlock FilterCount, OutputCount .globl MlasConvPostProcessFloatAvx512FFilter\FilterCount\()Output\OutputCount\() +#if !defined(__APPLE__) + .hidden MlasConvPostProcessFloatAvx512FFilter\FilterCount\()Output\OutputCount\() +#endif MlasConvPostProcessFloatAvx512FFilter\FilterCount\()Output\OutputCount\(): .if \FilterCount\() > 2 diff --git a/onnxruntime/core/mlas/lib/x86_64/SconvKernelSse2.S b/onnxruntime/core/mlas/lib/x86_64/SconvKernelSse2.S index e5505ea48942e..4dbbf696e96f7 100644 --- a/onnxruntime/core/mlas/lib/x86_64/SconvKernelSse2.S +++ b/onnxruntime/core/mlas/lib/x86_64/SconvKernelSse2.S @@ -249,6 +249,9 @@ Arguments: .macro PostProcessBlock FilterCount, OutputCount .globl MlasConvPostProcessFloatSseFilter\FilterCount\()Output\OutputCount\() +#if !defined(__APPLE__) + .hidden MlasConvPostProcessFloatSseFilter\FilterCount\()Output\OutputCount\() +#endif MlasConvPostProcessFloatSseFilter\FilterCount\()Output\OutputCount\(): .if \FilterCount\() > 2 diff --git a/onnxruntime/core/mlas/lib/x86_64/SgemmKernelAvx.S b/onnxruntime/core/mlas/lib/x86_64/SgemmKernelAvx.S index 0147f08edc821..63c6d5d2c837e 100644 --- a/onnxruntime/core/mlas/lib/x86_64/SgemmKernelAvx.S +++ b/onnxruntime/core/mlas/lib/x86_64/SgemmKernelAvx.S @@ -374,10 +374,9 @@ C_UNDERSCORE(MlasSgemmKernel\Mode\()Avx): .L\Mode\().OutputMasked8x4Block: vmovd xmm0,r9d - mov rbp,QWORD PTR C_UNDERSCORE(MlasMaskMoveAvx)@GOTPCREL[rip] vshufps xmm0,xmm0,xmm0,0 - vpcmpgtd xmm1,xmm0,XMMWORD PTR [rbp+16] - vpcmpgtd xmm0,xmm0,XMMWORD PTR [rbp] + vpcmpgtd xmm1,xmm0,XMMWORD PTR C_UNDERSCORE(MlasMaskMoveAvx)[rip+16] + vpcmpgtd xmm0,xmm0,XMMWORD PTR C_UNDERSCORE(MlasMaskMoveAvx)[rip] vinsertf128 ymm0,ymm0,xmm1,1 .ifeqs "\Mode\()","Add" vmaskmovps ymm8,ymm0,YMMWORD PTR [rdx] @@ -473,10 +472,9 @@ C_UNDERSCORE(MlasSgemmKernel\Mode\()Avx): .L\Mode\().OutputMasked8x2Block: vmovd xmm0,r9d - mov rbp,QWORD PTR C_UNDERSCORE(MlasMaskMoveAvx)@GOTPCREL[rip] vshufps xmm0,xmm0,xmm0,0 - vpcmpgtd xmm1,xmm0,XMMWORD PTR [rbp+16] - vpcmpgtd xmm0,xmm0,XMMWORD PTR [rbp] + vpcmpgtd xmm1,xmm0,XMMWORD PTR C_UNDERSCORE(MlasMaskMoveAvx)[rip+16] + vpcmpgtd xmm0,xmm0,XMMWORD PTR C_UNDERSCORE(MlasMaskMoveAvx)[rip] vinsertf128 ymm0,ymm0,xmm1,1 .ifeqs "\Mode\()","Add" vmaskmovps ymm8,ymm0,YMMWORD PTR [rdx] @@ 
-540,10 +538,9 @@ C_UNDERSCORE(MlasSgemmKernel\Mode\()Avx): .L\Mode\().OutputMasked8x1Block: vmovd xmm0,r9d - mov rbp,QWORD PTR C_UNDERSCORE(MlasMaskMoveAvx)@GOTPCREL[rip] vshufps xmm0,xmm0,xmm0,0 - vpcmpgtd xmm1,xmm0,XMMWORD PTR [rbp+16] - vpcmpgtd xmm0,xmm0,XMMWORD PTR [rbp] + vpcmpgtd xmm1,xmm0,XMMWORD PTR C_UNDERSCORE(MlasMaskMoveAvx)[rip+16] + vpcmpgtd xmm0,xmm0,XMMWORD PTR C_UNDERSCORE(MlasMaskMoveAvx)[rip] vinsertf128 ymm0,ymm0,xmm1,1 .ifeqs "\Mode\()","Add" vmaskmovps ymm8,ymm0,YMMWORD PTR [rdx] diff --git a/onnxruntime/core/mlas/lib/x86_64/SgemmKernelFma3.S b/onnxruntime/core/mlas/lib/x86_64/SgemmKernelFma3.S index cfeceb6be30f3..a7382f897946b 100644 --- a/onnxruntime/core/mlas/lib/x86_64/SgemmKernelFma3.S +++ b/onnxruntime/core/mlas/lib/x86_64/SgemmKernelFma3.S @@ -435,9 +435,8 @@ C_UNDERSCORE(MlasSgemmKernel\Mode\()Fma3): .L\Mode\().OutputMasked8x6Block: mov DWORD PTR [rsp+SgemmKernelFrame_mask],r9d - mov rbp,QWORD PTR C_UNDERSCORE(MlasMaskMoveAvx)@GOTPCREL[rip] vbroadcastss ymm0,DWORD PTR [rsp+SgemmKernelFrame_mask] - vpcmpgtd ymm0,ymm0,YMMWORD PTR [rbp] + vpcmpgtd ymm0,ymm0,YMMWORD PTR C_UNDERSCORE(MlasMaskMoveAvx)[rip] .ifeqs "\Mode\()","Add" vmaskmovps ymm4,ymm0,YMMWORD PTR [rdx] vmaskmovps ymm6,ymm0,YMMWORD PTR [rdx+rax] @@ -550,9 +549,8 @@ C_UNDERSCORE(MlasSgemmKernel\Mode\()Fma3): .L\Mode\().OutputMasked8x3Block: mov DWORD PTR [rsp+SgemmKernelFrame_mask],r9d - mov rbp,QWORD PTR C_UNDERSCORE(MlasMaskMoveAvx)@GOTPCREL[rip] vbroadcastss ymm0,DWORD PTR [rsp+SgemmKernelFrame_mask] - vpcmpgtd ymm0,ymm0,YMMWORD PTR [rbp] + vpcmpgtd ymm0,ymm0,YMMWORD PTR C_UNDERSCORE(MlasMaskMoveAvx)[rip] .ifeqs "\Mode\()","Add" vmaskmovps ymm4,ymm0,YMMWORD PTR [rdx] vmaskmovps ymm6,ymm0,YMMWORD PTR [rdx+rax] @@ -653,9 +651,8 @@ C_UNDERSCORE(MlasSgemmKernel\Mode\()Fma3): .L\Mode\().OutputMasked8x1Block: mov DWORD PTR [rsp+SgemmKernelFrame_mask],r9d - mov rbp,QWORD PTR C_UNDERSCORE(MlasMaskMoveAvx)@GOTPCREL[rip] vbroadcastss ymm0,DWORD PTR [rsp+SgemmKernelFrame_mask] - vpcmpgtd ymm0,ymm0,YMMWORD PTR [rbp] + vpcmpgtd ymm0,ymm0,YMMWORD PTR C_UNDERSCORE(MlasMaskMoveAvx)[rip] .ifeqs "\Mode\()","Add" vmaskmovps ymm4,ymm0,YMMWORD PTR [rdx] vfmadd213ps ymm5,ymm2,ymm4 diff --git a/onnxruntime/core/mlas/lib/x86_64/SgemmKernelM1Avx.S b/onnxruntime/core/mlas/lib/x86_64/SgemmKernelM1Avx.S index 28fca0e956640..86bc82b23071b 100644 --- a/onnxruntime/core/mlas/lib/x86_64/SgemmKernelM1Avx.S +++ b/onnxruntime/core/mlas/lib/x86_64/SgemmKernelM1Avx.S @@ -80,10 +80,9 @@ C_UNDERSCORE(MlasSgemmKernelM1Avx): mov eax,r8d and eax,7 vmovd xmm7,eax - mov rbx,QWORD PTR C_UNDERSCORE(MlasMaskMoveAvx)@GOTPCREL[rip] vshufps xmm7,xmm7,xmm7,0 - vpcmpgtd xmm6,xmm7,XMMWORD PTR [rbx+16] - vpcmpgtd xmm7,xmm7,XMMWORD PTR [rbx] + vpcmpgtd xmm6,xmm7,XMMWORD PTR C_UNDERSCORE(MlasMaskMoveAvx)[rip+16] + vpcmpgtd xmm7,xmm7,XMMWORD PTR C_UNDERSCORE(MlasMaskMoveAvx)[rip] vinsertf128 ymm7,ymm7,xmm6,1 // diff --git a/onnxruntime/core/mlas/lib/x86_64/SgemmKernelM1TransposeBAvx.S b/onnxruntime/core/mlas/lib/x86_64/SgemmKernelM1TransposeBAvx.S index 8d5ff17f90084..86bc9209fa248 100644 --- a/onnxruntime/core/mlas/lib/x86_64/SgemmKernelM1TransposeBAvx.S +++ b/onnxruntime/core/mlas/lib/x86_64/SgemmKernelM1TransposeBAvx.S @@ -79,10 +79,9 @@ C_UNDERSCORE(MlasSgemmKernelM1TransposeBAvx): mov eax,ecx and eax,7 vmovd xmm7,eax - mov rbx,QWORD PTR C_UNDERSCORE(MlasMaskMoveAvx)@GOTPCREL[rip] vshufps xmm7,xmm7,xmm7,0 - vpcmpgtd xmm6,xmm7,XMMWORD PTR [rbx+16] - vpcmpgtd xmm7,xmm7,XMMWORD PTR [rbx] + vpcmpgtd xmm6,xmm7,XMMWORD PTR 
C_UNDERSCORE(MlasMaskMoveAvx)[rip+16] + vpcmpgtd xmm7,xmm7,XMMWORD PTR C_UNDERSCORE(MlasMaskMoveAvx)[rip] vinsertf128 ymm7,ymm7,xmm6,1 // diff --git a/onnxruntime/core/mlas/lib/x86_64/TanhKernelFma3.S b/onnxruntime/core/mlas/lib/x86_64/TanhKernelFma3.S index 61bbef5c91171..dd5584648dbe7 100644 --- a/onnxruntime/core/mlas/lib/x86_64/TanhKernelFma3.S +++ b/onnxruntime/core/mlas/lib/x86_64/TanhKernelFma3.S @@ -72,7 +72,7 @@ Return Value: .globl C_UNDERSCORE(MlasTanhKernelFma3) C_UNDERSCORE(MlasTanhKernelFma3): - mov rax,C_UNDERSCORE(MlasTanhConstants)@GOTPCREL[rip] + lea rax,C_UNDERSCORE(MlasTanhConstants)[rip] vbroadcastss ymm4,TanhConstants_LowerRange[rax] vbroadcastss ymm5,TanhConstants_UpperRange[rax] vbroadcastss ymm6,TanhConstants_alpha_13[rax] @@ -116,9 +116,8 @@ C_UNDERSCORE(MlasTanhKernelFma3): add rdx,8 # correct for over-subtract above jz .LExitKernel mov DWORD PTR TanhKernelFrame_CountN[rsp],edx - mov rcx,QWORD PTR C_UNDERSCORE(MlasMaskMoveAvx)@GOTPCREL[rip] vbroadcastss ymm2,DWORD PTR TanhKernelFrame_CountN[rsp] - vpcmpgtd ymm2,ymm2,YMMWORD PTR [rcx] + vpcmpgtd ymm2,ymm2,YMMWORD PTR C_UNDERSCORE(MlasMaskMoveAvx)[rip] vmaskmovps ymm0,ymm2,YMMWORD PTR [rdi] vmaxps ymm0,ymm4,ymm0 # clamp lower bound vminps ymm0,ymm5,ymm0 # clamp upper bound diff --git a/onnxruntime/core/optimizer/optimizer_execution_frame.cc b/onnxruntime/core/optimizer/optimizer_execution_frame.cc index cd64d2398228b..fb14f762969d8 100644 --- a/onnxruntime/core/optimizer/optimizer_execution_frame.cc +++ b/onnxruntime/core/optimizer/optimizer_execution_frame.cc @@ -1,3 +1,5 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. #include "core/common/common.h" #include "core/common/status.h" @@ -9,7 +11,7 @@ #include "core/framework/mldata_type_utils.h" #include "core/framework/kernel_registry.h" #include "core/framework/fuse_nodes_funcs.h" -#include "core/common/callback.h" +#include "core/framework/callback.h" #include "core/optimizer/optimizer_execution_frame.h" namespace onnxruntime { diff --git a/onnxruntime/core/optimizer/optimizer_execution_frame.h b/onnxruntime/core/optimizer/optimizer_execution_frame.h index 41f5e85215699..cb507feb9a57a 100644 --- a/onnxruntime/core/optimizer/optimizer_execution_frame.h +++ b/onnxruntime/core/optimizer/optimizer_execution_frame.h @@ -11,7 +11,7 @@ #include "core/framework/execution_frame.h" #include "core/framework/ort_value_name_idx_map.h" #include "core/framework/ml_value.h" -#include "core/common/callback.h" +#include "core/framework/callback.h" namespace onnxruntime { class DataTransferManager; diff --git a/onnxruntime/core/optimizer/transformer_memcpy.cc b/onnxruntime/core/optimizer/transformer_memcpy.cc index c6f57900b9881..bdf44761bba18 100644 --- a/onnxruntime/core/optimizer/transformer_memcpy.cc +++ b/onnxruntime/core/optimizer/transformer_memcpy.cc @@ -52,7 +52,7 @@ class TransformerMemcpyImpl { std::string provider_; }; -/** Helper that returns a pointer to the corresponding TensorProto for a name if it is an initializer. +/** Helper that returns a pointer to the corresponding TensorProto for a name if it is an initializer. @param check_outer_scope If true and the graph is a subgraph, check parent graph/s for 'name' if not found in 'graph'. 
*/ static const onnx::TensorProto* GetInitializer(const Graph& graph, const std::string& name, bool check_outer_scope) { @@ -73,7 +73,6 @@ common::Status MemcpyTransformer::ApplyImpl(Graph& graph, bool& modified, int gr provider != onnxruntime::kMklDnnExecutionProvider && provider != onnxruntime::kNGraphExecutionProvider && provider != onnxruntime::kNupharExecutionProvider && - provider != onnxruntime::kTensorrtExecutionProvider && provider != onnxruntime::kOpenVINOExecutionProvider) { TransformerMemcpyImpl copy_impl(graph, provider); auto current_modified = copy_impl.ModifyGraph(registry_manager_); @@ -100,7 +99,7 @@ common::Status MemcpyTransformer::ApplyImpl(Graph& graph, bool& modified, int gr Overview: The transformer transforms the input graph as follows: -(1) For every initializer W that is referenced by both provider and non-provider nodes, +(1) For every initializer W that is referenced by both provider and non-provider nodes, we create a duplicate initializer W2 and change all provider nodes to reference this duplicate copy. @@ -167,7 +166,9 @@ bool TransformerMemcpyImpl::ModifyGraph(const KernelRegistryManager& kernel_regi } void TransformerMemcpyImpl::ProcessDefs(onnxruntime::Node& node, const KernelRegistryManager& kernel_registries, InitializedTensorSet& initializers_consumed) { - if (node.GetExecutionProviderType() == provider_) { + if (node.GetExecutionProviderType() == provider_ + || (node.GetExecutionProviderType() == kCudaExecutionProvider && provider_ == kTensorrtExecutionProvider) + || (node.GetExecutionProviderType() == kTensorrtExecutionProvider && provider_ == kCudaExecutionProvider)) { provider_nodes_.insert(&node); // note KernelCreateInfo might be nullptr for custom kernel const KernelCreateInfo* kci = nullptr; @@ -206,7 +207,7 @@ void TransformerMemcpyImpl::ProcessDefs(onnxruntime::Node& node, const KernelReg } } else { // TODO: copy between devices? i.e. multiple GPUs - if (node.GetExecutionProviderType() != onnxruntime::kCpuExecutionProvider && node.GetExecutionProviderType() != onnxruntime::kTensorrtExecutionProvider && + if (node.GetExecutionProviderType() != onnxruntime::kCpuExecutionProvider && node.GetExecutionProviderType() != onnxruntime::kNGraphExecutionProvider && !node.GetExecutionProviderType().empty()) { ORT_THROW("Execution type '", node.GetExecutionProviderType(), "' doesn't support memcpy "); } diff --git a/onnxruntime/core/platform/env.h b/onnxruntime/core/platform/env.h index d27cdbaae6833..c9199faf7f168 100644 --- a/onnxruntime/core/platform/env.h +++ b/onnxruntime/core/platform/env.h @@ -24,7 +24,7 @@ limitations under the License. 
#include #include "core/common/common.h" -#include "core/common/callback.h" +#include "core/framework/callback.h" #include "core/platform/env_time.h" #ifndef _WIN32 diff --git a/onnxruntime/core/platform/posix/env.cc b/onnxruntime/core/platform/posix/env.cc index 2b7e50dee38aa..2d34bea1e2bfc 100644 --- a/onnxruntime/core/platform/posix/env.cc +++ b/onnxruntime/core/platform/posix/env.cc @@ -41,7 +41,7 @@ namespace onnxruntime { namespace { constexpr int OneMillion = 1000000; -static void ORT_API_CALL DeleteBuffer(void* param) noexcept { ::free(param); } +static void DeleteBuffer(void* param) noexcept { ::free(param); } class UnmapFileParam { public: @@ -50,7 +50,7 @@ class UnmapFileParam { int fd; }; -static void ORT_API_CALL UnmapFile(void* param) noexcept { +static void UnmapFile(void* param) noexcept { UnmapFileParam* p = reinterpret_cast(param); int ret = munmap(p->addr, p->len); if (ret != 0) { @@ -124,7 +124,7 @@ class PosixEnv : public Env { } common::Status ReadFileAsString(const char* fname, off_t offset, void*& p, size_t& len, - OrtCallback& deleter) const override { + OrtCallback& deleter) const override { if (!fname) { return common::Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, "ReadFileAsString: 'fname' cannot be NULL"); } @@ -180,7 +180,7 @@ class PosixEnv : public Env { char buf[1024]; const char* msg = ""; if (e > 0) { -#if defined(__GLIBC__) && defined(_GNU_SOURCE) && !defined (__ANDROID__) +#if defined(__GLIBC__) && defined(_GNU_SOURCE) && !defined(__ANDROID__) msg = strerror_r(e, buf, sizeof(buf)); #else // for Mac OS X and Android lower than API 23 diff --git a/onnxruntime/core/platform/windows/env.cc b/onnxruntime/core/platform/windows/env.cc index 1f0e9bcac4410..010077d07b273 100644 --- a/onnxruntime/core/platform/windows/env.cc +++ b/onnxruntime/core/platform/windows/env.cc @@ -30,7 +30,7 @@ namespace onnxruntime { namespace { -static void ORT_API_CALL DeleteBuffer(void* param) noexcept { ::free(param); } +static void DeleteBuffer(void* param) noexcept { ::free(param); } class WindowsEnv : public Env { public: diff --git a/onnxruntime/core/providers/common.h b/onnxruntime/core/providers/common.h index 16a065cb8cc1a..c23dded2f9c40 100644 --- a/onnxruntime/core/providers/common.h +++ b/onnxruntime/core/providers/common.h @@ -4,6 +4,7 @@ #pragma once #include "core/common/common.h" +#include "core/framework/tensor.h" namespace onnxruntime { @@ -20,4 +21,25 @@ inline int64_t HandleNegativeAxis(int64_t axis, int64_t tensor_rank) { return axis = axis < 0 ? 
axis + tensor_rank : axis; } +/** +Returns true if given tensor is a scalar or 1D tensor of size 1 +**/ +inline bool IsScalarOr1ElementVector(const Tensor* input) { + if (input->Shape().NumDimensions() == 0 || + (input->Shape().NumDimensions() == 1 && input->Shape().GetDims().size() == 1)) { + return true; + } else { + return false; + } +} + +/** +Clamps input between provided min and max values +**/ +inline float clamp(float v, float lo, float hi) { + if (v < lo) return lo; + if (v > hi) return hi; + return v; +} + } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cpu/controlflow/loop.cc b/onnxruntime/core/providers/cpu/controlflow/loop.cc index 3937ab4ed53f3..89dbf9f8d783d 100644 --- a/onnxruntime/core/providers/cpu/controlflow/loop.cc +++ b/onnxruntime/core/providers/cpu/controlflow/loop.cc @@ -257,7 +257,6 @@ Status LoopImpl::CreateFeedsFetchesManager(std::unique_ptr& feed_names.push_back(entry.first); } - FeedsFetchesInfo ffi(feed_names, subgraph_output_names_); auto status = FeedsFetchesManager::Create(feed_names, subgraph_output_names_, session_state_.GetOrtValueNameIdxMap(), ffm); diff --git a/onnxruntime/core/providers/cpu/controlflow/scan_utils.cc b/onnxruntime/core/providers/cpu/controlflow/scan_utils.cc index 821c84a78c723..2c696066556e4 100644 --- a/onnxruntime/core/providers/cpu/controlflow/scan_utils.cc +++ b/onnxruntime/core/providers/cpu/controlflow/scan_utils.cc @@ -125,7 +125,6 @@ Status CreateFeedsFetchesManager(const GraphViewer& subgraph, int num_variadic_i feed_names.push_back(entry.first); } - FeedsFetchesInfo ffi(feed_names, subgraph_output_names); auto status = FeedsFetchesManager::Create(feed_names, subgraph_output_names, ort_value_name_idx_map, ffm); return status; diff --git a/onnxruntime/core/providers/cpu/controlflow/utils.h b/onnxruntime/core/providers/cpu/controlflow/utils.h index b5a39bcfdef4d..d3427e9104a50 100644 --- a/onnxruntime/core/providers/cpu/controlflow/utils.h +++ b/onnxruntime/core/providers/cpu/controlflow/utils.h @@ -26,7 +26,8 @@ common::Status SubgraphExecuteHelper(std::unique_ptr& cache } else { // use a local instance until we know we're successful, and cache if it is std::unique_ptr new_ffm; - impl.CreateFeedsFetchesManager(new_ffm); + ORT_RETURN_IF_ERROR(impl.CreateFeedsFetchesManager(new_ffm)); + status = impl.Execute(&*new_ffm, nullptr); if (status.IsOK()) { cached_feeds_fetches_manager = std::move(new_ffm); diff --git a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc index 08b6a31938111..eb1c26203308f 100644 --- a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc +++ b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc @@ -9,12 +9,16 @@ #include "contrib_ops/cpu_contrib_kernels.h" #endif +#ifdef MICROSOFT_AUTOML +#include "automl_ops/cpu_automl_kernels.h" +#endif + #include "core/framework/compute_capability.h" namespace onnxruntime { // Forward declarations of op kernels -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, Clip); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, 10, Clip); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, Elu); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, HardSigmoid); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, LeakyRelu); @@ -132,6 +136,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, class 
ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, int32_t, ReduceLogSumExp); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, float, ReduceMax); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, int32_t, ReduceMax); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, int64_t, ReduceMax); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, float, ReduceMean); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, int32_t, ReduceMean); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, float, ReduceMin); @@ -141,6 +146,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, float, ReduceSum); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, double, ReduceSum); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, int32_t, ReduceSum); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, int64_t, ReduceSum); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, float, ReduceSumSquare); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, double, ReduceSumSquare); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, int32_t, ReduceSumSquare); @@ -218,6 +224,7 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, Mea class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, int32_t, Greater); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, int64_t, Greater); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, int32_t, Less); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, int64_t, Less); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, string, Cast); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, EyeLike); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, float, IsNaN); @@ -232,6 +239,8 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, float_float_float, OneHot); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, int64_t_int32_t_float, OneHot); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, int64_t_float_int64_t, OneHot); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, int32_t_float_int32_t, OneHot); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, int32_t_float_float, OneHot); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, MaxUnpool); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, Sinh); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, Cosh); @@ -247,9 +256,11 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, float, NonZero); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, 
kOnnxDomain, 9, int32_t, NonZero); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, int64_t, NonZero); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, uint8_t, NonZero); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, string, Where); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, float, Where); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, int32_t, Where); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, int64_t, Where); // Opset 10 class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, StringNormalizer); @@ -263,9 +274,11 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, ThresholdedRelu); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, uint8_t, DequantizeLinear); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, int8_t, DequantizeLinear); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, float, QuantizeLinear); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, uint8_t, QuantizeLinear); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, int8_t, QuantizeLinear); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, QLinearMatMul); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, MatMulInteger); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, uint8_t, MatMulInteger); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, int8_t, MatMulInteger); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, ConvInteger); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, QLinearConv); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, bool, Slice); @@ -288,9 +301,13 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, double, RoiAlign); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, ReverseSequence); +// opset 11 +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, Clip); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, uint8_t, DynamicQuantizeLinear); + void RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) { static const BuildKernelCreateInfoFn function_table[] = { - BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -408,6 +425,7 @@ void RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -415,8 +433,9 @@ void RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -494,6 +513,7 @@ void 
RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -508,6 +528,8 @@ void RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -523,9 +545,11 @@ void RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, // Opset 10 BuildKernelCreateInfo, @@ -539,9 +563,11 @@ void RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -563,6 +589,10 @@ void RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + + //opset 11 + BuildKernelCreateInfo, + BuildKernelCreateInfo, }; for (auto& function_table_entry : function_table) { @@ -588,7 +618,8 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 1, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 1, int64_t_double, DictVectorizer); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 1, FeatureVectorizer); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 1, Imputer); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 1, LabelEncoder); + +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 1, 1, LabelEncoder); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 1, float, LinearClassifier); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 1, double, LinearClassifier); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 1, int64_t, LinearClassifier); @@ -615,6 +646,13 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 1, class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 1, TreeEnsembleRegressor); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 1, ZipMap); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 2, float_string, LabelEncoder); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 2, string_float, LabelEncoder); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 2, int64_float, LabelEncoder); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 2, float_int64, LabelEncoder); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 2, int64_string, LabelEncoder); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 2, string_int64, LabelEncoder); + void RegisterOnnxMLOperatorKernels(KernelRegistry& kernel_registry) { static const BuildKernelCreateInfoFn function_table[] = { BuildKernelCreateInfo, @@ -633,7 +671,7 @@ void 
RegisterOnnxMLOperatorKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -659,6 +697,13 @@ void RegisterOnnxMLOperatorKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, }; for (auto& function_table_entry : function_table) { @@ -673,6 +718,9 @@ static void RegisterCPUKernels(KernelRegistry& kernel_registry) { #ifndef DISABLE_CONTRIB_OPS ::onnxruntime::contrib::RegisterCpuContribKernels(kernel_registry); #endif +#ifdef MICROSOFT_AUTOML + ::onnxruntime::automl::RegisterCpuAutoMLKernels(kernel_registry); +#endif } std::shared_ptr GetCpuKernelRegistry() { diff --git a/onnxruntime/core/providers/cpu/generator/random.cc b/onnxruntime/core/providers/cpu/generator/random.cc index fc8f7c5917873..a688241d71867 100644 --- a/onnxruntime/core/providers/cpu/generator/random.cc +++ b/onnxruntime/core/providers/cpu/generator/random.cc @@ -76,8 +76,6 @@ void GenerateData(std::default_random_engine& generator, TDistribution distribut static Status RandomNormalCompute(float mean, float scale, std::default_random_engine& generator, TensorProto::DataType dtype, Tensor& Y); static Status RandomUniformCompute(float high, float low, std::default_random_engine& generator, TensorProto::DataType dtype, Tensor& Y); -// Leaving in case we need to change to this approach -//static Status CreateOutputTensorFromTensorValues(OpKernelContext* ctx, const Tensor& X,Tensor** Y); static Status CreateOutputTensorFromTensorShape(OpKernelContext* ctx, const Tensor& X, Tensor** Y); static TensorProto::DataType InferDataType(const Tensor& tensor); @@ -168,53 +166,48 @@ static Status MultinomialCompute(OpKernelContext* ctx, Eigen::array Y_dims = {{batch_size, num_samples}}; Matrix output = Matrix(Y.template MutableData(), Y_dims); - // TODO (perf optimization) - the idea behind making this a lambda is so that we can parallelize across batches. - // When we do that this lamdba will act as one task given to a thread - auto DoWork = [ctx, num_samples, num_classes, &generator, &logits, &output](int64_t start_row, - int64_t limit_row) { - std::default_random_engine generator_copy = generator; - // BEGIN create temporary tensor - AllocatorPtr alloc; - ctx->GetTempSpaceAllocator(&alloc); - auto cdf_data = static_cast(alloc->Alloc(sizeof(double) * num_classes)); - BufferUniquePtr cdf_buffer(cdf_data, BufferDeleter(alloc)); - Eigen::array cdf_dims = {{num_classes}}; - auto cdf = EigenVector(cdf_data, cdf_dims); - // END create temporary tensor - - std::uniform_real_distribution dist(0.0, 1.0); // TODO: should this be initialized per batch? - for (int64_t b = start_row; b < limit_row; ++b) { - const float* logits_row = &(logits(b, 0)); - // Takes an along-class maximum (for numerical stability). 
- float maxx = std::numeric_limits::lowest(); - for (int64_t j = 0; j < num_classes; ++j) { - if (Eigen::numext::isfinite(logits_row[j])) { - maxx = std::max(maxx, logits_row[j]); - } + // BEGIN create temporary tensor + AllocatorPtr alloc; + ORT_RETURN_IF_ERROR(ctx->GetTempSpaceAllocator(&alloc)); + auto cdf_data = static_cast(alloc->Alloc(sizeof(double) * num_classes)); + BufferUniquePtr cdf_buffer(cdf_data, BufferDeleter(alloc)); + Eigen::array cdf_dims = {{num_classes}}; + auto cdf = EigenVector(cdf_data, cdf_dims); + // END create temporary tensor + + std::uniform_real_distribution dist(0.0, 1.0); // TODO: should this be initialized per batch? + + for (int64_t b = 0; b < batch_size; ++b) { + const float* logits_row = &(logits(b, 0)); + // Takes an along-class maximum (for numerical stability). + float maxx = std::numeric_limits::lowest(); + for (int64_t j = 0; j < num_classes; ++j) { + if (Eigen::numext::isfinite(logits_row[j])) { + maxx = std::max(maxx, logits_row[j]); } - const auto max_logit = static_cast(maxx); - - // Precompute cumulative probability distribution across classes. - // Note: This isn't normalized. - cdf = (logits.chip<0>(b).cast() - max_logit).exp(); - double running_total = 0; - for (int64_t j = 0; j < num_classes; ++j) { - if (Eigen::numext::isfinite(logits_row[j])) { - running_total += cdf(j); - } - cdf(j) = running_total; - } - // Generate each sample. - const double* cdf_begin = cdf.data(); - const double* cdf_end = cdf.data() + num_classes; - for (int64_t j = 0; j < num_samples; ++j) { - const double to_find = dist(generator_copy) * running_total; - auto found_iter = std::upper_bound(cdf_begin, cdf_end, to_find); - output(b, j) = static_cast(std::distance(cdf_begin, found_iter)); + } + const auto max_logit = static_cast(maxx); + + // Precompute cumulative probability distribution across classes. + // Note: This isn't normalized. + cdf = (logits.chip<0>(b).cast() - max_logit).exp(); + double running_total = 0; + for (int64_t j = 0; j < num_classes; ++j) { + if (Eigen::numext::isfinite(logits_row[j])) { + running_total += cdf(j); } + cdf(j) = running_total; + } + // Generate each sample. + const double* cdf_begin = cdf.data(); + const double* cdf_end = cdf.data() + num_classes; + for (int64_t j = 0; j < num_samples; ++j) { + const double to_find = dist(generator) * running_total; + auto found_iter = std::upper_bound(cdf_begin, cdf_end, to_find); + output(b, j) = static_cast(std::distance(cdf_begin, found_iter)); } - }; - DoWork(0, batch_size); + } + return Status::OK(); } @@ -262,32 +255,6 @@ Status Multinomial::Compute(OpKernelContext* ctx) const { return status; } -/* -alternative interpretation of the spec is that the input tensor contains the dimensions as ints. -Keeping this temporarily in case we go back to that. - -// read shape information from input tensor and create output tensor with it -static Status CreateOutputTensorFromTensorValues(OpKernelContext* ctx, const Tensor& X, Tensor** Y) { - const TensorShape& shape = X.Shape(); - auto size = shape.Size(); - auto num_dims = shape.NumDimensions(); - - if (num_dims != 1) { - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Expected 1 dimension tensor with shape information. 
Dimensions=", num_dims); - } - - std::vector dims; - dims.reserve(shape.Size()); - - auto data = gsl::make_span(tensor.template Data(), shape.Size()); - dims.insert(dims.cbegin(), data.cbegin(), data.cend()); - - *Y = ctx->Output(0, TensorShape(dims)); - - return Status::OK(); -} -*/ - // create output tensor using shape of input tensor static Status CreateOutputTensorFromTensorShape(OpKernelContext* ctx, const Tensor& X, Tensor** Y) { const TensorShape& shape = X.Shape(); @@ -363,9 +330,11 @@ static Status RandomUniformCompute(float low, float high, template void GenerateData(std::default_random_engine& generator, TDistribution distribution, Tensor& tensor) { - auto out = gsl::make_span(tensor.template MutableData(), tensor.Shape().Size()); - - std::for_each(out.begin(), out.end(), [&generator, &distribution](T& value) { value = distribution(generator); }); + T* out = tensor.MutableData(); + for (int64_t i = 0, end = tensor.Shape().Size(); i < end; ++i) { + *out = distribution(generator); + ++out; + } } } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cpu/generator/random.h b/onnxruntime/core/providers/cpu/generator/random.h index 6ef8d2c460553..639341d1a29cc 100644 --- a/onnxruntime/core/providers/cpu/generator/random.h +++ b/onnxruntime/core/providers/cpu/generator/random.h @@ -20,11 +20,14 @@ class RandomNormal final : public OpKernel { // read optional seed attribute and generate if not provided float seed = 0.f; - if (!info.GetAttr("seed", &seed).IsOK()) { - seed = gsl::narrow_cast(std::chrono::high_resolution_clock::now().time_since_epoch().count()); + if (info.GetAttr("seed", &seed).IsOK()) { + generator_ = std::default_random_engine{gsl::narrow_cast(seed)}; + } + else { + generator_ = std::default_random_engine{ + gsl::narrow_cast(std::chrono::high_resolution_clock::now().time_since_epoch().count()) + }; } - - generator_ = std::default_random_engine{gsl::narrow_cast(seed)}; int64_t dtype; ORT_ENFORCE(info.GetAttr("dtype", &dtype).IsOK()); @@ -60,11 +63,14 @@ class RandomNormalLike final : public OpKernel { // read optional seed attribute and generate if not provided float seed = 0.f; - if (!info.GetAttr("seed", &seed).IsOK()) { - seed = gsl::narrow_cast(std::chrono::high_resolution_clock::now().time_since_epoch().count()); + if (info.GetAttr("seed", &seed).IsOK()) { + generator_ = std::default_random_engine{gsl::narrow_cast(seed)}; + } + else { + generator_ = std::default_random_engine{ + gsl::narrow_cast(std::chrono::high_resolution_clock::now().time_since_epoch().count()) + }; } - - generator_ = std::default_random_engine{gsl::narrow_cast(seed)}; int64_t dtype; if (info.GetAttr("dtype", &dtype).IsOK()) { @@ -94,11 +100,14 @@ class RandomUniform final : public OpKernel { // read optional seed attribute and generate if not provided float seed = 0.f; - if (!info.GetAttr("seed", &seed).IsOK()) { - seed = gsl::narrow_cast(std::chrono::high_resolution_clock::now().time_since_epoch().count()); + if (info.GetAttr("seed", &seed).IsOK()) { + generator_ = std::default_random_engine{gsl::narrow_cast(seed)}; + } + else { + generator_ = std::default_random_engine{ + gsl::narrow_cast(std::chrono::high_resolution_clock::now().time_since_epoch().count()) + }; } - - generator_ = std::default_random_engine{gsl::narrow_cast(seed)}; int64_t dtype; ORT_ENFORCE(info.GetAttr("dtype", &dtype).IsOK()); @@ -131,11 +140,14 @@ class RandomUniformLike final : public OpKernel { ORT_ENFORCE(info.GetAttr("low", &low_).IsOK()); // read optional seed attribute and generate if not provided float 
seed = 0.f; - if (!info.GetAttr("seed", &seed).IsOK()) { - seed = gsl::narrow_cast(std::chrono::high_resolution_clock::now().time_since_epoch().count()); + if (info.GetAttr("seed", &seed).IsOK()) { + generator_ = std::default_random_engine{gsl::narrow_cast(seed)}; + } + else { + generator_ = std::default_random_engine{ + gsl::narrow_cast(std::chrono::high_resolution_clock::now().time_since_epoch().count()) + }; } - - generator_ = std::default_random_engine{gsl::narrow_cast(seed)}; int64_t dtype; if (info.GetAttr("dtype", &dtype).IsOK()) { @@ -163,11 +175,14 @@ class Multinomial final : public OpKernel { ORT_ENFORCE(info.GetAttr("sample_size", &num_samples_).IsOK()); float seed = 0.f; - if (!info.GetAttr("seed", &seed).IsOK()) { - seed = gsl::narrow_cast(std::chrono::high_resolution_clock::now().time_since_epoch().count()); + if (info.GetAttr("seed", &seed).IsOK()) { + generator_ = std::default_random_engine{gsl::narrow_cast(seed)}; + } + else { + generator_ = std::default_random_engine{ + gsl::narrow_cast(std::chrono::high_resolution_clock::now().time_since_epoch().count()) + }; } - - generator_ = std::default_random_engine{gsl::narrow_cast(seed)}; int64_t output_dtype_tmp; if (!info.GetAttr("dtype", &output_dtype_tmp).IsOK()) { diff --git a/onnxruntime/core/providers/cpu/math/clip.cc b/onnxruntime/core/providers/cpu/math/clip.cc index 160d587df7238..dc99582ddca04 100644 --- a/onnxruntime/core/providers/cpu/math/clip.cc +++ b/onnxruntime/core/providers/cpu/math/clip.cc @@ -5,9 +5,16 @@ namespace onnxruntime { -ONNX_CPU_OPERATOR_KERNEL( +ONNX_CPU_OPERATOR_VERSIONED_KERNEL( Clip, 6, + 10, + KernelDefBuilder().MayInplace(0, 0).TypeConstraint("T", DataTypeImpl::GetTensorType()), + Clip_6); + +ONNX_CPU_OPERATOR_KERNEL( + Clip, + 11, KernelDefBuilder().MayInplace(0, 0).TypeConstraint("T", DataTypeImpl::GetTensorType()), Clip); diff --git a/onnxruntime/core/providers/cpu/math/clip.h b/onnxruntime/core/providers/cpu/math/clip.h index 653967547bb02..b4ef64398dddf 100644 --- a/onnxruntime/core/providers/cpu/math/clip.h +++ b/onnxruntime/core/providers/cpu/math/clip.h @@ -10,9 +10,9 @@ namespace onnxruntime { template -class Clip final : public OpKernel { +class Clip_6 final : public OpKernel { public: - Clip(const OpKernelInfo& info) : OpKernel(info) { + Clip_6(const OpKernelInfo& info) : OpKernel(info) { ORT_ENFORCE(info.GetAttr("max", &max_).IsOK()); ORT_ENFORCE(info.GetAttr("min", &min_).IsOK()); } @@ -32,4 +32,36 @@ class Clip final : public OpKernel { T min_; }; +template +class Clip final : public OpKernel { + public: + Clip(const OpKernelInfo& info) : OpKernel(info) { + } + + Status Compute(OpKernelContext* ctx) const override { + const auto* X = ctx->Input(0); + const auto* min = ctx->Input(1); + const auto* max = ctx->Input(2); + Tensor* Y = ctx->Output(0, X->Shape()); + + auto min_val = -std::numeric_limits::infinity(); + auto max_val = std::numeric_limits::infinity(); + if (min) { + ORT_ENFORCE(min->Shape().NumDimensions() == 0, "min should be a scalar."); + min_val = *(min->template Data()); + } + if (max) { + ORT_ENFORCE(max->Shape().NumDimensions() == 0, "max should be a scalar."); + max_val = *(max->template Data()); + } + + EigenVectorMap(Y->template MutableData(), Y->Shape().Size()) = + ConstEigenVectorMap(X->template Data(), X->Shape().Size()) + .cwiseMax(min_val) + .cwiseMin(max_val); + + return Status::OK(); + } +}; + } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cpu/math/element_wise_ops.cc b/onnxruntime/core/providers/cpu/math/element_wise_ops.cc index 
ece68834e7525..7c13c98a745a6 100644 --- a/onnxruntime/core/providers/cpu/math/element_wise_ops.cc +++ b/onnxruntime/core/providers/cpu/math/element_wise_ops.cc @@ -18,6 +18,15 @@ namespace onnxruntime { KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), \ KERNEL_CLASS); +#define REG_ELEMENTWISE_LOGICALOP_TYPED_KERNEL(OP_TYPE, VERSION, TYPE, KERNEL_CLASS) \ + ONNX_CPU_OPERATOR_TYPED_KERNEL( \ + OP_TYPE, \ + VERSION, \ + TYPE, \ + KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()) \ + .TypeConstraint("T1", DataTypeImpl::GetTensorType()), \ + KERNEL_CLASS); + #define REG_ELEMENTWISE_VERSIONED_TYPED_KERNEL(OP_TYPE, VERSION_FROM, VERSION_TO, TYPE, KERNEL_CLASS) \ ONNX_CPU_OPERATOR_VERSIONED_TYPED_KERNEL( \ OP_TYPE, \ @@ -26,6 +35,15 @@ namespace onnxruntime { KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), \ KERNEL_CLASS); +#define REG_ELEMENTWISE_LOGICALOP_VERSIONED_TYPED_KERNEL(OP_TYPE, VERSION_FROM, VERSION_TO, TYPE, KERNEL_CLASS) \ + ONNX_CPU_OPERATOR_VERSIONED_TYPED_KERNEL( \ + OP_TYPE, \ + VERSION_FROM, VERSION_TO, \ + TYPE, \ + KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()) \ + .TypeConstraint("T1", DataTypeImpl::GetTensorType()), \ + KERNEL_CLASS); + REG_ELEMENTWISE_TYPED_KERNEL(Add, 7, float, Add); REG_ELEMENTWISE_TYPED_KERNEL(Add, 7, double, Add); REG_ELEMENTWISE_TYPED_KERNEL(Add, 7, int32_t, Add); @@ -88,45 +106,55 @@ REG_ELEMENTWISE_VERSIONED_TYPED_KERNEL(Max, 6, 7, float, Max_6); REG_ELEMENTWISE_TYPED_KERNEL(Max, 8, float, Max_8); REG_ELEMENTWISE_TYPED_KERNEL(Max, 8, double, Max_8); -REG_ELEMENTWISE_VERSIONED_TYPED_KERNEL(Less, 7, 9, float, Less); -REG_ELEMENTWISE_TYPED_KERNEL(Less, 9, int32_t, Less); +REG_ELEMENTWISE_LOGICALOP_VERSIONED_TYPED_KERNEL(Less, 7, 9, float, Less); +REG_ELEMENTWISE_LOGICALOP_TYPED_KERNEL(Less, 9, int32_t, Less); +REG_ELEMENTWISE_LOGICALOP_TYPED_KERNEL(Less, 9, int64_t, Less); -REG_ELEMENTWISE_VERSIONED_TYPED_KERNEL(Greater, 7, 9, float, Greater) -REG_ELEMENTWISE_TYPED_KERNEL(Greater, 9, int32_t, Greater); -REG_ELEMENTWISE_TYPED_KERNEL(Greater, 9, int64_t, Greater); +REG_ELEMENTWISE_LOGICALOP_VERSIONED_TYPED_KERNEL(Greater, 7, 9, float, Greater) +REG_ELEMENTWISE_LOGICALOP_TYPED_KERNEL(Greater, 9, int32_t, Greater); +REG_ELEMENTWISE_LOGICALOP_TYPED_KERNEL(Greater, 9, int64_t, Greater); -REG_ELEMENTWISE_TYPED_KERNEL(Equal, 7, bool, Equal); -REG_ELEMENTWISE_TYPED_KERNEL(Equal, 7, int32_t, Equal); -REG_ELEMENTWISE_TYPED_KERNEL(Equal, 7, int64_t, Equal); -REG_ELEMENTWISE_TYPED_KERNEL(Equal, 11, float, Equal); +REG_ELEMENTWISE_LOGICALOP_TYPED_KERNEL(Equal, 7, bool, Equal); +REG_ELEMENTWISE_LOGICALOP_TYPED_KERNEL(Equal, 7, int32_t, Equal); +REG_ELEMENTWISE_LOGICALOP_TYPED_KERNEL(Equal, 7, int64_t, Equal); +REG_ELEMENTWISE_LOGICALOP_TYPED_KERNEL(Equal, 11, float, Equal); REG_ELEMENTWISE_VERSIONED_TYPED_KERNEL(Mean, 6, 7, float, Mean_6); REG_ELEMENTWISE_TYPED_KERNEL(Mean, 8, float, Mean_8); REG_ELEMENTWISE_TYPED_KERNEL(Erf, 9, float, Erf); +// REG_ELEMENTWISE_LOGICALOP_TYPED_KERNEL(Not, 1, bool, Not); +// REG_ELEMENTWISE_LOGICALOP_TYPED_KERNEL(And, 7, bool, And); +// REG_ELEMENTWISE_LOGICALOP_TYPED_KERNEL(Or, 7, bool, Or); +// REG_ELEMENTWISE_LOGICALOP_TYPED_KERNEL(Xor, 7, bool, Xor); + ONNX_CPU_OPERATOR_KERNEL( Not, 1, - KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), + KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()) + .TypeConstraint("T1", DataTypeImpl::GetTensorType()), Not); ONNX_CPU_OPERATOR_KERNEL( And, 7, - 
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), + KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()) + .TypeConstraint("T1", DataTypeImpl::GetTensorType()), And); ONNX_CPU_OPERATOR_KERNEL( Or, 7, - KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), + KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()) + .TypeConstraint("T1", DataTypeImpl::GetTensorType()), Or); ONNX_CPU_OPERATOR_KERNEL( Xor, 7, - KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), + KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()) + .TypeConstraint("T1", DataTypeImpl::GetTensorType()), Xor); template diff --git a/onnxruntime/core/providers/cpu/math/element_wise_ops.h b/onnxruntime/core/providers/cpu/math/element_wise_ops.h index 035ed3ee69cfb..e9d28c8314adb 100644 --- a/onnxruntime/core/providers/cpu/math/element_wise_ops.h +++ b/onnxruntime/core/providers/cpu/math/element_wise_ops.h @@ -320,6 +320,11 @@ struct BroadcastIterator { return index; } + void Reserve(int64_t max_dims) { + deltas_.reserve(max_dims); + counts_.reserve(max_dims); + } + void Init(int64_t axis, int64_t largest) { ORT_ENFORCE(axis == 1 || axis == largest, "Attempting to broadcast an axis by a dimension other than 1. ", axis, " by ", largest); @@ -368,6 +373,8 @@ struct Broadcaster { size_t dimension_count_max = std::max(shape1.size(), shape2.size()); size_t dimension_count_min = std::min(shape1.size(), shape2.size()); output_shape_.resize(dimension_count_max); + iterator1_.Reserve(dimension_count_max); + iterator2_.Reserve(dimension_count_max); auto iter1 = shape1.end(); auto iter2 = shape2.end(); @@ -395,22 +402,22 @@ struct Broadcaster { *--output_shape = axis; } index++; // Manually increment since we processed one axis - } - - for (; index < dimension_count_min; index++) { - auto axis1 = *--iter1; - auto axis2 = *--iter2; + } else { + for (; index < dimension_count_min; index++) { + auto axis1 = *--iter1; + auto axis2 = *--iter2; - auto largest = std::max(axis1, axis2); - *--output_shape = largest; + auto largest = std::max(axis1, axis2); + *--output_shape = largest; - if (largest == 1 && index + 1 < dimension_count_min) // Nothing to do in this case - continue; + if (largest == 1 && index + 1 < dimension_count_min) // Nothing to do in this case + continue; - iterator1_.Init(axis1, largest); - iterator2_.Init(axis2, largest); - index++; // Manually increment since we processed one axis - break; + iterator1_.Init(axis1, largest); + iterator2_.Init(axis2, largest); + index++; // Manually increment since we processed one axis + break; + } } for (; index < dimension_count_min; index++) { diff --git a/onnxruntime/core/providers/cpu/math/gemm.h b/onnxruntime/core/providers/cpu/math/gemm.h index a3aa724ab410d..225754141a6d7 100644 --- a/onnxruntime/core/providers/cpu/math/gemm.h +++ b/onnxruntime/core/providers/cpu/math/gemm.h @@ -8,6 +8,7 @@ #include "core/util/math.h" #include "core/util/math_cpuonly.h" #include "gemm_helper.h" +#include "core/framework/op_kernel_context_internal.h" namespace onnxruntime { @@ -27,6 +28,9 @@ class Gemm : public OpKernel { } Status Compute(OpKernelContext* context) const override { + auto ctx_internal = static_cast(context); + concurrency::ThreadPool* tp = ctx_internal->GetOperatorThreadPool(); + const auto X = context->Input(0); const auto W = context->Input(1); const auto B = context->Input(2); @@ -64,7 +68,7 @@ class Gemm : public OpKernel { } // W * x - math::Gemm( + math::Gemm( trans_A_, trans_B_, 
M, @@ -75,7 +79,7 @@ class Gemm : public OpKernel { W->template Data(), beta_, y_data, - &CPUMathUtil::Instance()); + tp); FuseActivation(activation_, y_data, M * N, leaky_relu_alpha_); diff --git a/onnxruntime/core/providers/cpu/math/logsoftmax.cc b/onnxruntime/core/providers/cpu/math/logsoftmax.cc index 281031e71568e..19fbb9897c699 100644 --- a/onnxruntime/core/providers/cpu/math/logsoftmax.cc +++ b/onnxruntime/core/providers/cpu/math/logsoftmax.cc @@ -4,6 +4,8 @@ #include "core/providers/cpu/math/logsoftmax.h" #include "core/framework/op_kernel.h" +#include "core/framework/op_kernel_context_internal.h" + #include "core/providers/common.h" #include "core/providers/cpu/math/softmax_shared.h" #include "core/util/math.h" @@ -12,6 +14,9 @@ namespace onnxruntime { template <> Status LogSoftmax::Compute(OpKernelContext* ctx) const { + auto ctx_internal = static_cast(ctx); + concurrency::ThreadPool* tp = ctx_internal->GetOperatorThreadPool(); + const auto* tensor_pointer = ctx->Input(0); if (tensor_pointer == nullptr) return Status(common::ONNXRUNTIME, common::FAIL, "input count mismatch"); const Tensor& X = *tensor_pointer; @@ -32,7 +37,7 @@ Status LogSoftmax::Compute(OpKernelContext* ctx) const { const bool logarithmic = true; auto status = SoftmaxCPU(N, D, X.template Data(), Ydata, - scale_.data(), sum_multiplier_.data(), logarithmic, rowmax_.data()); + scale_.data(), sum_multiplier_.data(), logarithmic, rowmax_.data(), tp); return status; } diff --git a/onnxruntime/core/providers/cpu/math/matmul.cc b/onnxruntime/core/providers/cpu/math/matmul.cc index 539157e92bd95..4f4bacc34baeb 100644 --- a/onnxruntime/core/providers/cpu/math/matmul.cc +++ b/onnxruntime/core/providers/cpu/math/matmul.cc @@ -1,6 +1,6 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. 
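// The LogSoftmax hunk above (and the Softmax / softmax_shared hunks further below) route an
// operator-level thread pool into SoftmaxCPU(N, D, X, Y, scale, sum_multiplier, logarithmic,
// rowmax, tp), which treats the input as an N x D matrix and applies softmax or log-softmax
// along each row. The function below is a minimal, single-threaded reference of that math only;
// the name ReferenceSoftmax is illustrative, and it deliberately ignores the scratch buffers and
// GEMM-based broadcasting the production code uses.
#include <algorithm>
#include <cmath>
#include <cstdint>

void ReferenceSoftmax(int64_t N, int64_t D, const float* X, float* Y, bool logarithmic) {
  for (int64_t n = 0; n < N; ++n) {
    const float* x = X + n * D;
    float* y = Y + n * D;
    // Subtract the row maximum before exponentiating for numerical stability.
    const float row_max = *std::max_element(x, x + D);
    float sum = 0.0f;
    for (int64_t d = 0; d < D; ++d) {
      y[d] = std::exp(x[d] - row_max);
      sum += y[d];
    }
    for (int64_t d = 0; d < D; ++d) {
      // log(exp(x - max) / sum) simplifies to (x - max) - log(sum).
      y[d] = logarithmic ? (x[d] - row_max) - std::log(sum) : y[d] / sum;
    }
  }
}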
- +#include "core/framework/op_kernel_context_internal.h" #include "core/providers/cpu/math/matmul.h" #include "core/util/math.h" @@ -53,6 +53,9 @@ ONNX_CPU_OPERATOR_VERSIONED_TYPED_KERNEL( template Status MatMul::Compute(OpKernelContext* ctx) const { + auto ctx_internal = static_cast(ctx); + concurrency::ThreadPool* thread_pool = ctx_internal->GetOperatorThreadPool(); + const auto* left_X = ctx->Input(0); const auto* right_X = ctx->Input(1); @@ -69,7 +72,7 @@ Status MatMul::Compute(OpKernelContext* ctx) const { static_cast(helper.K()), left_X->template Data() + helper.LeftOffsets()[i], right_X->template Data() + helper.RightOffsets()[i], - Y->template MutableData() + helper.OutputOffsets()[i]); + Y->template MutableData() + helper.OutputOffsets()[i], thread_pool); } return Status::OK(); diff --git a/onnxruntime/core/providers/cpu/math/matmul_helper.h b/onnxruntime/core/providers/cpu/math/matmul_helper.h index af82037a7c465..e5095e0ea1382 100644 --- a/onnxruntime/core/providers/cpu/math/matmul_helper.h +++ b/onnxruntime/core/providers/cpu/math/matmul_helper.h @@ -29,9 +29,8 @@ class MatMulComputeHelper { M_ = left_shape.SizeToDimension(left_num_dims - 1); K_ = left_shape[left_num_dims - 1]; N_ = right_shape[right_num_dims - 1]; - std::vector output_dims = left_shape.GetDims(); - output_dims[left_num_dims - 1] = N_; - output_shape_ = TensorShape(output_dims); + output_shape_ = left_shape; + output_shape_[left_num_dims - 1] = N_; output_offsets_ = {0}; left_offsets_ = {0}; right_offsets_ = {0}; diff --git a/onnxruntime/core/providers/cpu/math/matmul_integer.cc b/onnxruntime/core/providers/cpu/math/matmul_integer.cc index 9a64e3fe42094..eab5434d24fb9 100644 --- a/onnxruntime/core/providers/cpu/math/matmul_integer.cc +++ b/onnxruntime/core/providers/cpu/math/matmul_integer.cc @@ -1,49 +1,40 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#ifdef _MSC_VER -#pragma warning(disable : 4244) -#pragma warning(disable : 4267) -#endif - #include "core/providers/cpu/math/matmul_integer.h" #include "core/providers/cpu/math/matmul_helper.h" -#include "core/util/gemmlowp_common_wrapper.h" +#include "core/util/qmath.h" +#include "core/providers/common.h" namespace onnxruntime { // only register this operator if low precision computation is enabled. -ONNX_OPERATOR_KERNEL_EX( +ONNX_OPERATOR_TYPED_KERNEL_EX( MatMulInteger, kOnnxDomain, 10, + uint8_t, kCpuExecutionProvider, KernelDefBuilder() .TypeConstraint("T1", DataTypeImpl::GetTensorType()) .TypeConstraint("T2", DataTypeImpl::GetTensorType()) .TypeConstraint("T3", DataTypeImpl::GetTensorType()), - MatMulInteger); - -Status GemmlowpMultiply(const uint8_t* lhs_data, const uint8_t* rhs_data, - int32_t* result_data, const int lhs_offset, const int rhs_offset, - int m, int n, int k) { - const std::tuple<> empty_pipeline = {}; - // TODO exp ColMajor order for rhs and result. 
That may be faster - const auto matOrder = gemmlowp::MapOrder::RowMajor; - gemmlowp::MatrixMap lhs(lhs_data, m, k); - gemmlowp::MatrixMap rhs(rhs_data, k, n); - gemmlowp::MatrixMap result(result_data, m, n); + MatMulInteger); - gemmlowp::GemmContext gemm_context; - gemmlowp::GemmWithOutputPipeline( - &gemm_context, lhs, rhs, &result, -lhs_offset, -rhs_offset, empty_pipeline); - - return Status::OK(); -} +ONNX_OPERATOR_TYPED_KERNEL_EX( + MatMulInteger, + kOnnxDomain, + 10, + int8_t, + kCpuExecutionProvider, + KernelDefBuilder() + .TypeConstraint("T1", DataTypeImpl::GetTensorType()) + .TypeConstraint("T2", DataTypeImpl::GetTensorType()) + .TypeConstraint("T3", DataTypeImpl::GetTensorType()), + MatMulInteger); -template<> -Status MatMulInteger::Compute(OpKernelContext* ctx) const { +template <> +Status MatMulInteger::Compute(OpKernelContext* ctx) const { auto a = ctx->Input(0); auto b = ctx->Input(1); ORT_ENFORCE(a != nullptr && b != nullptr); @@ -53,34 +44,79 @@ Status MatMulInteger::Compute(OpKernelContext* ctx) c Tensor* y = ctx->Output(0, helper.OutputShape()); // validate zero points - int32_t a_offset = 0; - int32_t b_offset = 0; + uint8_t a_offset = 0; + uint8_t b_offset = 0; if (has_a_zero_point_) { auto a_zero_point = ctx->Input(2); - ORT_ENFORCE(a_zero_point->Shape().NumDimensions() == 0 || - (a_zero_point->Shape().NumDimensions() == 1 && a_zero_point->Shape().GetDims().size() == 1), - "Currently only scalar zero_point is supported. TODO: add per channel zero point support."); + ORT_ENFORCE(IsScalarOr1ElementVector(a_zero_point), + "MatmulInteger : input1 zero point must be a scalar or 1D tensor of size 1"); a_offset = static_cast(*a_zero_point->template Data()); } if (has_b_zero_point_) { auto b_zero_point = ctx->Input(3); - ORT_ENFORCE(b_zero_point->Shape().NumDimensions() == 0 || - (b_zero_point->Shape().NumDimensions() == 1 && b_zero_point->Shape().GetDims().size() == 1), - "Currently only scalar zero_point is supported. TODO: add per channel zero point support."); + ORT_ENFORCE(IsScalarOr1ElementVector(b_zero_point), + "MatmulInteger : input2 zero point must be a scalar or 1D tensor of size 1"); b_offset = static_cast(*b_zero_point->template Data()); } for (size_t i = 0; i < helper.OutputOffsets().size(); i++) { - GemmlowpMultiply(a->template Data() + helper.LeftOffsets()[i], - b->template Data() + helper.RightOffsets()[i], - y->template MutableData() + helper.OutputOffsets()[i], - a_offset, - b_offset, - static_cast(helper.M()), - static_cast(helper.N()), - static_cast(helper.K())); + QGemmu8u8_s32(static_cast(helper.M()), + static_cast(helper.N()), + static_cast(helper.K()), + a->template Data() + helper.LeftOffsets()[i], + static_cast(helper.K()), + a_offset, + b->template Data() + helper.RightOffsets()[i], + static_cast(helper.N()), + b_offset, + y->template MutableData() + helper.OutputOffsets()[i], + static_cast(helper.N()), + nullptr); } + return Status::OK(); +} +template <> +Status MatMulInteger::Compute(OpKernelContext* ctx) const { + auto a = ctx->Input(0); + auto b = ctx->Input(1); + ORT_ENFORCE(a != nullptr && b != nullptr); + + MatMulComputeHelper helper; + ORT_RETURN_IF_ERROR(helper.Compute(a->Shape(), b->Shape())); + Tensor* y = ctx->Output(0, helper.OutputShape()); + + if (has_a_zero_point_ || has_b_zero_point_) { + // currently zero point is only supported in Gemmlowp path above + // in future, the selection of Eigen/Gemmlowp/mklml/etc. 
should be in a common math library like SGEMM + + auto IsZeroPointTensorAllZero = [](OpKernelContext* ctx, int input_idx) -> bool { + auto t = ctx->Input(input_idx); + ORT_ENFORCE(t->Shape().NumDimensions() <= 1 && t->Shape().Size() == 1, + "Currently only scalar zero_point is supported. TODO: add per channel zero point support."); + ORT_ENFORCE(t->DataType() == DataTypeImpl::GetType() || + t->DataType() == DataTypeImpl::GetType()); + auto data = reinterpret_cast(t->DataRaw()); + auto vec = std::vector(data, data + t->Shape().Size()); + return std::all_of(vec.begin(), vec.end(), [](int8_t v) { return v == 0; }); + }; + + if ((has_a_zero_point_ && !IsZeroPointTensorAllZero(ctx, 2)) || + (has_b_zero_point_ && !IsZeroPointTensorAllZero(ctx, 3))) { + ORT_NOT_IMPLEMENTED("MatMulInteger: Unsupported input types with zero point"); + } + } + + // NOTE: Eigen based implementation is a reference implementation for accuracy only + for (int i = 0; i < static_cast(helper.OutputOffsets().size()); i++) { + EigenCastGEMM( + a->template Data() + helper.LeftOffsets()[i], + b->template Data() + helper.RightOffsets()[i], + y->template MutableData() + helper.OutputOffsets()[i], + static_cast(helper.M()), + static_cast(helper.N()), + static_cast(helper.K())); + } return Status::OK(); } } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cpu/math/matmul_integer.h b/onnxruntime/core/providers/cpu/math/matmul_integer.h index d9b5bbfbc9361..36e9c11707674 100644 --- a/onnxruntime/core/providers/cpu/math/matmul_integer.h +++ b/onnxruntime/core/providers/cpu/math/matmul_integer.h @@ -9,14 +9,14 @@ namespace onnxruntime { -template +template class MatMulInteger final : public OpKernel { public: MatMulInteger(const OpKernelInfo& info) : OpKernel(info) { has_a_zero_point_ = false; has_b_zero_point_ = false; if (info.GetInputCount() > 2) { - has_a_zero_point_ = true; + has_a_zero_point_ = true; } if (info.GetInputCount() > 3) { has_b_zero_point_ = true; @@ -29,4 +29,4 @@ class MatMulInteger final : public OpKernel { bool has_a_zero_point_; bool has_b_zero_point_; }; -} // namespace onnxruntime \ No newline at end of file +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/cpu/math/quantize_linear_matmul.cc b/onnxruntime/core/providers/cpu/math/quantize_linear_matmul.cc index b3cf0dcbe7094..164b58f208f70 100644 --- a/onnxruntime/core/providers/cpu/math/quantize_linear_matmul.cc +++ b/onnxruntime/core/providers/cpu/math/quantize_linear_matmul.cc @@ -1,14 +1,9 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. 
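// The hunk below drops this file's local QuantizeMultiplier helper (the Compute body keeps
// calling QuantizeMultiplier, which appears to come from the shared gemmlowp_common header now
// included by quantize_linear_matmul.h). For reference, the helper decomposes the real multiplier
// a_scale * b_scale / y_scale into a Q31 fixed-point multiplier plus a right shift. The sketch
// below reproduces that decomposition with std::frexp; it is equivalent to the original bit
// manipulation for positive, normal inputs, and the name QuantizeMultiplierSketch is illustrative
// only.
#include <cmath>
#include <cstdint>

void QuantizeMultiplierSketch(float fp_multiplier, int32_t* integer_multiplier, int* right_shift) {
  // Bring the multiplier into [0.5, 1) and remember how many bits that shifted it,
  // so that fp_multiplier == bumped * 2^exponent.
  int exponent = 0;
  const float bumped = std::frexp(fp_multiplier, &exponent);
  *right_shift = -exponent;
  // Scale the [0.5, 1) value to signed Q31. For multipliers below 1 (the usual case),
  // x * fp_multiplier is then approximately ((x * integer_multiplier) >> 31) >> right_shift.
  *integer_multiplier = static_cast<int32_t>(std::llround(bumped * (1ll << 31)));
}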
-#ifdef _MSC_VER -#pragma warning(disable : 4244) -#pragma warning(disable : 4267) -#endif - #include "core/providers/cpu/math/quantize_linear_matmul.h" #include "core/providers/cpu/math/matmul_helper.h" -#include "core/util/gemmlowp_common_wrapper.h" +#include "core/providers/common.h" namespace onnxruntime { @@ -24,55 +19,7 @@ ONNX_OPERATOR_KERNEL_EX( .TypeConstraint("T3", DataTypeImpl::GetTensorType()), QLinearMatMul); -Status GemmlowpMultiply(const uint8_t* lhs_data, const uint8_t* rhs_data, uint8_t* result_data, - const int lhs_offset, const int rhs_offset, const int result_offset, - int m, int n, int k, int32_t int_multiplier, int32_t right_shift) { - gemmlowp::OutputStageQuantizeDownInt32ByFixedPoint quantize_down_stage; - quantize_down_stage.result_offset_after_shift = result_offset; - quantize_down_stage.result_fixedpoint_multiplier = int_multiplier; - quantize_down_stage.result_shift = right_shift; - gemmlowp::OutputStageSaturatingCastToUint8 saturating_cast_stage; - const auto& output_pipeline = std::make_tuple(quantize_down_stage, saturating_cast_stage); - - // TODO exp ColMajor order for rhs and result. That may be faster - const auto matOrder = gemmlowp::MapOrder::RowMajor; - gemmlowp::MatrixMap lhs(lhs_data, m, k); - gemmlowp::MatrixMap rhs(rhs_data, k, n); - gemmlowp::MatrixMap result(result_data, m, n); - - gemmlowp::GemmContext gemm_context; - gemmlowp::GemmWithOutputPipeline( - &gemm_context, lhs, rhs, &result, -lhs_offset, -rhs_offset, output_pipeline); - - return Status::OK(); -} - -void QuantizeMultiplier(float fp_multiplier, std::int32_t* integer_multiplier, int* right_shift) { - auto* fp_as_bits = reinterpret_cast(&fp_multiplier); - auto current_exponent = (*fp_as_bits >> 23); - // bring multiplier in [.5,1) range and calculate the shift - auto bumped_multiplier_as_bits = - (*fp_as_bits & UINT32_C(0x007fffff)) | UINT32_C(0x3f000000); - auto* bumped_multiplier = reinterpret_cast(&bumped_multiplier_as_bits); - auto shift = 126 - current_exponent; - // convert to fixed point number - auto int_multiplier = static_cast(std::round(*bumped_multiplier * (1ll << 31))); - - *integer_multiplier = static_cast(int_multiplier); - *right_shift = shift; -} - -void ScaleAndZeropointPairValidationHelper(const Tensor* scale, const Tensor* zeropoint) { - ORT_ENFORCE(scale->Shape().NumDimensions() == 0 || - (scale->Shape().NumDimensions() == 1 && scale->Shape().GetDims().size() == 1), - "scale must be a scalar"); - ORT_ENFORCE(zeropoint->Shape().NumDimensions() == 0 || - (zeropoint->Shape().NumDimensions() == 1 && zeropoint->Shape().GetDims().size() == 1), - "zeropoint must be a scalar"); -} - -template<> +template <> Status QLinearMatMul::Compute(OpKernelContext* ctx) const { auto a = ctx->Input(0); auto b = ctx->Input(3); @@ -82,16 +29,27 @@ Status QLinearMatMul::Compute(OpKernelContext* ctx) c ORT_RETURN_IF_ERROR(helper.Compute(a->Shape(), b->Shape())); Tensor* y = ctx->Output(0, helper.OutputShape()); - // validate scale and zero points + // validate offsets + auto a_offset = ctx->Input(2); + auto b_offset = ctx->Input(5); + auto y_offset = ctx->Input(7); + ORT_ENFORCE(IsScalarOr1ElementVector(a_offset), + "QLinearMatmul : input zero point must be a scalar or 1D tensor of size 1"); + ORT_ENFORCE(IsScalarOr1ElementVector(b_offset), + "QLinearMatmul : weight zero point must be a scalar or 1D tensor of size 1"); + ORT_ENFORCE(IsScalarOr1ElementVector(y_offset), + "QLinearMatmul : result zero point must be a scalar or 1D tensor of size 1"); + + // validate scale auto a_scale = 
ctx->Input(1); - auto a_zero_point = ctx->Input(2); - ScaleAndZeropointPairValidationHelper(a_scale, a_zero_point); auto b_scale = ctx->Input(4); - auto b_zero_point = ctx->Input(5); - ScaleAndZeropointPairValidationHelper(b_scale, b_zero_point); auto y_scale = ctx->Input(6); - auto y_zero_point = ctx->Input(7); - ScaleAndZeropointPairValidationHelper(y_scale, y_zero_point); + ORT_ENFORCE(IsScalarOr1ElementVector(a_scale), + "QLinearMatmul : input scale must be a scalar or 1D tensor of size 1"); + ORT_ENFORCE(IsScalarOr1ElementVector(b_scale), + "QLinearMatmul : weight scale must be a scalar or 1D tensor of size 1"); + ORT_ENFORCE(IsScalarOr1ElementVector(y_scale), + "QLinearMatmul : result scale must be a scalar or 1D tensor of size 1"); auto a_scale_data = *(a_scale->template Data()); auto b_scale_data = *(b_scale->template Data()); @@ -103,17 +61,17 @@ Status QLinearMatMul::Compute(OpKernelContext* ctx) c QuantizeMultiplier(real_multiplier, &integer_multiplier, &right_shift); for (size_t i = 0; i < helper.OutputOffsets().size(); i++) { - GemmlowpMultiply(a->template Data() + helper.LeftOffsets()[i], - b->template Data() + helper.RightOffsets()[i], - y->template MutableData() + helper.OutputOffsets()[i], - *a_zero_point->template Data(), - *b_zero_point->template Data(), - *y_zero_point->template Data(), - static_cast(helper.M()), - static_cast(helper.N()), - static_cast(helper.K()), - integer_multiplier, - right_shift); + GemmlowpMultiplyu8u8_u8(a->template Data() + helper.LeftOffsets()[i], + b->template Data() + helper.RightOffsets()[i], + y->template MutableData() + helper.OutputOffsets()[i], + *a_offset->template Data(), + *b_offset->template Data(), + *y_offset->template Data(), + static_cast(helper.M()), + static_cast(helper.N()), + static_cast(helper.K()), + integer_multiplier, + right_shift); } return Status::OK(); diff --git a/onnxruntime/core/providers/cpu/math/quantize_linear_matmul.h b/onnxruntime/core/providers/cpu/math/quantize_linear_matmul.h index 778bb03ca0e84..aada308756e85 100644 --- a/onnxruntime/core/providers/cpu/math/quantize_linear_matmul.h +++ b/onnxruntime/core/providers/cpu/math/quantize_linear_matmul.h @@ -6,6 +6,7 @@ #include "core/common/common.h" #include "core/framework/op_kernel.h" #include "core/util/math_cpuonly.h" +#include "core/util/gemmlowp_common.h" namespace onnxruntime { @@ -16,6 +17,6 @@ class QLinearMatMul final : public OpKernel { } Status Compute(OpKernelContext* context) const override; - + }; } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cpu/math/softmax.cc b/onnxruntime/core/providers/cpu/math/softmax.cc index 9242967901e46..542e20e79f79c 100644 --- a/onnxruntime/core/providers/cpu/math/softmax.cc +++ b/onnxruntime/core/providers/cpu/math/softmax.cc @@ -4,6 +4,7 @@ #include "core/providers/cpu/math/softmax.h" #include "core/framework/op_kernel.h" +#include "core/framework/op_kernel_context_internal.h" #include "core/providers/common.h" #include "core/providers/cpu/math/softmax_shared.h" #include "core/util/math.h" @@ -12,6 +13,9 @@ namespace onnxruntime { template <> Status Softmax::Compute(OpKernelContext* ctx) const { + auto ctx_internal = static_cast(ctx); + concurrency::ThreadPool* tp = ctx_internal->GetOperatorThreadPool(); + const auto* tensor_pointer = ctx->Input(0); if (tensor_pointer == nullptr) return Status(common::ONNXRUNTIME, common::FAIL, "input count mismatch"); const Tensor& X = *tensor_pointer; @@ -34,7 +38,7 @@ Status Softmax::Compute(OpKernelContext* ctx) const { const bool logarithmic = false; auto 
status = SoftmaxCPU(N, D, X.template Data(), Ydata, - scale_.data(), sum_multiplier_.data(), logarithmic, rowmax_.data()); + scale_.data(), sum_multiplier_.data(), logarithmic, rowmax_.data(), tp); return status; } diff --git a/onnxruntime/core/providers/cpu/math/softmax_shared.cc b/onnxruntime/core/providers/cpu/math/softmax_shared.cc index 7dd3a10cfc598..18277f6b4137c 100644 --- a/onnxruntime/core/providers/cpu/math/softmax_shared.cc +++ b/onnxruntime/core/providers/cpu/math/softmax_shared.cc @@ -31,6 +31,7 @@ #endif #include "core/providers/cpu/math/softmax_shared.h" + #include "core/util/math.h" #include "core/util/math_cpuonly.h" @@ -46,7 +47,7 @@ common::Status SoftmaxCPU(const int64_t N, float* scale, const float* sum_multiplier, bool logarithmic, - float* rowmax) { + float* rowmax, onnxruntime::concurrency::ThreadPool* tp) { // the Math functions SoftmaxCPU uses only support int32_t as input, so enforce that if (N * D > INT32_MAX || N > INT32_MAX || D > INT32_MAX) { std::ostringstream ss; @@ -65,7 +66,7 @@ common::Status SoftmaxCPU(const int64_t N, // Put the intermediate result X - max(X) into Y by first copying X to Y, and then subtracting max from each entry gsl::copy(gsl::make_span(Xdata, nd), gsl::make_span(Ydata, nd)); - math::Gemm(CblasNoTrans, CblasNoTrans, n, d, 1, -1, rowmax, sum_multiplier, 1, Ydata, nullptr); + math::Gemm(CblasNoTrans, CblasNoTrans, n, d, 1, -1, rowmax, sum_multiplier, 1, Ydata, tp); // Exponentiation math::Exp(nd, Ydata, Ydata, nullptr); diff --git a/onnxruntime/core/providers/cpu/math/softmax_shared.h b/onnxruntime/core/providers/cpu/math/softmax_shared.h index 3439b9717f051..26ffeb193fe4f 100644 --- a/onnxruntime/core/providers/cpu/math/softmax_shared.h +++ b/onnxruntime/core/providers/cpu/math/softmax_shared.h @@ -6,6 +6,9 @@ #include "core/common/status.h" namespace onnxruntime { +namespace concurrency { +class ThreadPool; +} /** Calculate Softmax using CPU memory. @param N Number of rows @@ -18,5 +21,5 @@ Calculate Softmax using CPU memory. @param rowmax Storage for calculation of maximum in each row. Size must be >= N. 
*/ common::Status SoftmaxCPU(int64_t N, int64_t D, const float* Xdata, float* Ydata, float* scale, - const float* sum_multiplier, bool logarithmic, float* rowmax); + const float* sum_multiplier, bool logarithmic, float* rowmax, concurrency::ThreadPool* tp); } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cpu/ml/label_encoder.cc b/onnxruntime/core/providers/cpu/ml/label_encoder.cc index 4a2ac686b4480..b497300a72c89 100644 --- a/onnxruntime/core/providers/cpu/ml/label_encoder.cc +++ b/onnxruntime/core/providers/cpu/ml/label_encoder.cc @@ -9,15 +9,16 @@ using namespace ::onnxruntime::common; namespace onnxruntime { namespace ml { -ONNX_CPU_OPERATOR_ML_KERNEL( +ONNX_CPU_OPERATOR_VERSIONED_ML_KERNEL( LabelEncoder, - 1, + 1, 1, KernelDefBuilder().TypeConstraint("T1", std::vector{DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType()}) .TypeConstraint("T2", std::vector{DataTypeImpl::GetTensorType(), - DataTypeImpl::GetTensorType()}), + DataTypeImpl::GetTensorType()}) + .SinceVersion(1, 2), LabelEncoder); Status LabelEncoder::Compute(OpKernelContext* context) const { @@ -67,5 +68,107 @@ Status LabelEncoder::Compute(OpKernelContext* context) const { return Status::OK(); } +ONNX_CPU_OPERATOR_TYPED_ML_KERNEL( + LabelEncoder, + 2, + float_string, + KernelDefBuilder().TypeConstraint("T1", + std::vector{DataTypeImpl::GetTensorType()}) + .TypeConstraint("T2", + std::vector{DataTypeImpl::GetTensorType()}), + LabelEncoder_2); + +template <> +void LabelEncoder_2::InitializeSomeFields(const OpKernelInfo& info) { + _key_field_name = "keys_floats"; + _value_field_name = "values_strings"; + info.GetAttrOrDefault("default_string", &_default_value, std::string("_Unused")); +}; + +ONNX_CPU_OPERATOR_TYPED_ML_KERNEL( + LabelEncoder, + 2, + string_float, + KernelDefBuilder().TypeConstraint("T1", + std::vector{DataTypeImpl::GetTensorType()}) + .TypeConstraint("T2", + std::vector{DataTypeImpl::GetTensorType()}), + LabelEncoder_2); + +template <> +void LabelEncoder_2::InitializeSomeFields(const OpKernelInfo& info) { + _key_field_name = "keys_strings"; + _value_field_name = "values_floats"; + info.GetAttrOrDefault("default_float", &_default_value, -0.0f); +}; + +ONNX_CPU_OPERATOR_TYPED_ML_KERNEL( + LabelEncoder, + 2, + int64_float, + KernelDefBuilder().TypeConstraint("T1", + std::vector{DataTypeImpl::GetTensorType()}) + .TypeConstraint("T2", + std::vector{DataTypeImpl::GetTensorType()}), + LabelEncoder_2); + +template <> +void LabelEncoder_2::InitializeSomeFields(const OpKernelInfo& info) { + _key_field_name = "keys_int64s"; + _value_field_name = "values_floats"; + info.GetAttrOrDefault("default_float", &_default_value, -0.0f); +}; + +ONNX_CPU_OPERATOR_TYPED_ML_KERNEL( + LabelEncoder, + 2, + float_int64, + KernelDefBuilder().TypeConstraint("T1", + std::vector{DataTypeImpl::GetTensorType()}) + .TypeConstraint("T2", + std::vector{DataTypeImpl::GetTensorType()}), + LabelEncoder_2); + +template <> +void LabelEncoder_2::InitializeSomeFields(const OpKernelInfo& info) { + _key_field_name = "keys_floats"; + _value_field_name = "values_int64s"; + info.GetAttrOrDefault("default_int64", &_default_value, (std::int64_t)-1); +}; + +ONNX_CPU_OPERATOR_TYPED_ML_KERNEL( + LabelEncoder, + 2, + int64_string, + KernelDefBuilder().TypeConstraint("T1", + std::vector{DataTypeImpl::GetTensorType()}) + .TypeConstraint("T2", + std::vector{DataTypeImpl::GetTensorType()}), + LabelEncoder_2) + +template <> +void LabelEncoder_2::InitializeSomeFields(const OpKernelInfo& info) { + _key_field_name = "keys_int64s"; + 
_value_field_name = "values_strings"; + info.GetAttrOrDefault("default_string", &_default_value, std::string("_Unused")); +}; + +ONNX_CPU_OPERATOR_TYPED_ML_KERNEL( + LabelEncoder, + 2, + string_int64, + KernelDefBuilder().TypeConstraint("T1", + std::vector{DataTypeImpl::GetTensorType()}) + .TypeConstraint("T2", + std::vector{DataTypeImpl::GetTensorType()}), + LabelEncoder_2) + +template <> +void LabelEncoder_2::InitializeSomeFields(const OpKernelInfo& info) { + _key_field_name = "keys_strings"; + _value_field_name = "values_int64s"; + info.GetAttrOrDefault("default_int64", &_default_value, (std::int64_t)-1); +}; + } // namespace ml } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cpu/ml/label_encoder.h b/onnxruntime/core/providers/cpu/ml/label_encoder.h index 597cf240c6ed4..0f7c59b5740a0 100644 --- a/onnxruntime/core/providers/cpu/ml/label_encoder.h +++ b/onnxruntime/core/providers/cpu/ml/label_encoder.h @@ -43,5 +43,67 @@ class LabelEncoder final : public OpKernel { int64_t default_int_; }; +template +class LabelEncoder_2 final : public OpKernel { + public: + LabelEncoder_2(const OpKernelInfo& info) : OpKernel(info) { + // Let the specialized member function to tell which fields to load. + InitializeSomeFields(info); + + std::vector keys; + std::vector values; + + ORT_ENFORCE(info.GetAttrs(_key_field_name, keys).IsOK()); + ORT_ENFORCE(info.GetAttrs(_value_field_name, values).IsOK()); + + auto num_keys = keys.size(); + auto num_values = values.size(); + ORT_ENFORCE(num_keys == num_values, + "The ", _key_field_name, " and ", _value_field_name, " attribtues in LabelEncoder ", + "(name: ", info.node().Name(), ") must have the same length. ", + "However, the number of key is ", num_keys, " and the number of ", + "values is ", num_values, "."); + + for (size_t i = 0; i < num_keys; ++i) + _map[keys[i]] = values[i]; + } + + Status Compute(OpKernelContext* context) const override { + const auto* tensor_pointer = context->Input(0); + if (tensor_pointer == nullptr) return Status(common::ONNXRUNTIME, common::FAIL, "input count mismatch"); + const Tensor& X = *tensor_pointer; + const TensorShape& shape = X.Shape(); + Tensor& Y = *context->Output(0, TensorShape(shape)); + + auto input = X.template DataAsSpan(); + auto output = Y.template MutableDataAsSpan(); + + for (int64_t i = 0; i < shape.Size(); ++i) { + const auto found = _map.find(input[i]); + if (found == _map.end()) + output[i] = _default_value; + else + output[i] = found->second; + } + + return Status::OK(); + } + + private: + // Specialize this method to set attribute names. For example, if keys' type + // is 64-bit integer, _key_field_name should be "keys_int64s". Field names + // for other types can be found in ONNX spec. + void InitializeSomeFields(const OpKernelInfo& info); + + // A collection of key-value pairs. Each (a_key, a_value) pair + // means that the "a_key" in the input would be mapped to "a_value". + // If _map doesn't contain "a_key", we use _default_value as its output. + std::unordered_map _map; + TValue _default_value; + // ONNX attribute name to load keys. + std::string _key_field_name; + // ONNX attribute name to load values. 
+ std::string _value_field_name; +}; } // namespace ml } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cpu/nn/Unpool.cc b/onnxruntime/core/providers/cpu/nn/Unpool.cc index 3b1c16f354a55..853bd05cdd8d0 100644 --- a/onnxruntime/core/providers/cpu/nn/Unpool.cc +++ b/onnxruntime/core/providers/cpu/nn/Unpool.cc @@ -18,9 +18,9 @@ ONNX_CPU_OPERATOR_KERNEL( MaxUnpool, 9, KernelDefBuilder() - .TypeConstraint("T", DataTypeImpl::GetTensorType()) - .TypeConstraint("I", DataTypeImpl::GetTensorType()) - .TypeConstraint("Y", DataTypeImpl::GetTensorType()), + .TypeConstraint("T1", DataTypeImpl::GetTensorType()) + .TypeConstraint("T2", DataTypeImpl::GetTensorType()), + // .TypeConstraint("Y", DataTypeImpl::GetTensorType()), MaxUnpool); Status MaxUnpool::Compute(OpKernelContext* context) const { diff --git a/onnxruntime/core/providers/cpu/nn/conv.cc b/onnxruntime/core/providers/cpu/nn/conv.cc index c3acbd02a62c5..c0091936704d8 100644 --- a/onnxruntime/core/providers/cpu/nn/conv.cc +++ b/onnxruntime/core/providers/cpu/nn/conv.cc @@ -14,6 +14,7 @@ * limitations under the License. */ /* Modifications Copyright (c) Microsoft. */ +#include "core/framework/op_kernel_context_internal.h" #include "core/providers/cpu/nn/conv.h" #include "core/framework/op_kernel_context_internal.h" @@ -24,6 +25,8 @@ namespace onnxruntime { template Status Conv::Compute(OpKernelContext* context) const { size_t num_inputs = OpKernel::Node().InputDefs().size(); + auto ctx_internal = static_cast(context); + concurrency::ThreadPool* tp = ctx_internal->GetOperatorThreadPool(); const auto* X = context->Input(0); const auto* W = context->Input(1); @@ -116,7 +119,7 @@ Status Conv::Compute(OpKernelContext* context) const { col_buffer_data, &CPUMathUtil::Instance()); } - math::Gemm( + math::Gemm( CblasNoTrans, CblasNoTrans, M / group_, @@ -127,7 +130,7 @@ Status Conv::Compute(OpKernelContext* context) const { col_buffer_data, 0, Ydata + group_id * Y_offset, - &CPUMathUtil::Instance()); + tp); } if (B != nullptr) { @@ -144,6 +147,9 @@ Status Conv::Compute(OpKernelContext* context) const { } Status Conv::Compute(OpKernelContext* context) const { + auto ctx_internal = static_cast(context); + concurrency::ThreadPool* tp = ctx_internal->GetOperatorThreadPool(); + size_t num_inputs = OpKernel::Node().InputDefs().size(); const auto* X = context->Input(0); const auto* W = context->Input(1); @@ -186,11 +192,6 @@ Status Conv::Compute(OpKernelContext* context) const { const size_t kernel_rank = kernel_shape.size(); if (kernel_rank == 2 || kernel_rank == 3) { - // Get access to the internal threadpool - // Temporarily derive concurrency parameters without access to session state - auto ctx_internal = static_cast(context); - auto thread_pool = ctx_internal->GetOperatorThreadPool(); - MLAS_CONV_PARAMETERS Parameters; size_t WorkingBufferSize; MlasConvPrepare(&Parameters, @@ -207,7 +208,7 @@ Status Conv::Compute(OpKernelContext* context) const { static_cast(M / group_), &activation_, &WorkingBufferSize, - const_cast(thread_pool)); + tp); auto working_data = WorkingBufferSize > 0 ? 
alloc->Alloc(sizeof(float) * WorkingBufferSize) : nullptr; BufferUniquePtr working_buffer(working_data, BufferDeleter(alloc)); @@ -218,7 +219,7 @@ Status Conv::Compute(OpKernelContext* context) const { Bdata, static_cast(working_buffer.get()), Ydata, - const_cast(thread_pool)); + tp); } else { const int64_t input_image_size = input_shape.Size(); const int64_t output_image_size = output_shape.Size(); @@ -253,7 +254,7 @@ Status Conv::Compute(OpKernelContext* context) const { static_cast(kernel_shape.size()), col_buffer_data, &CPUMathUtil::Instance()); - math::Gemm( + math::Gemm( CblasNoTrans, CblasNoTrans, M / group_, @@ -264,7 +265,7 @@ Status Conv::Compute(OpKernelContext* context) const { col_buffer_data, 0, Ydata + group_id * Y_offset, - &CPUMathUtil::Instance()); + tp); } MlasActivation(&activation_, Ydata, Bdata, M, output_image_size, output_image_size); diff --git a/onnxruntime/core/providers/cpu/nn/conv_integer.cc b/onnxruntime/core/providers/cpu/nn/conv_integer.cc index fbd182312f554..534cb75a6e840 100644 --- a/onnxruntime/core/providers/cpu/nn/conv_integer.cc +++ b/onnxruntime/core/providers/cpu/nn/conv_integer.cc @@ -1,15 +1,11 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#ifdef _MSC_VER -#pragma warning(disable : 4244) -#pragma warning(disable : 4267) -#endif - #include "core/providers/cpu/nn/conv_integer.h" #include "core/util/math.h" #include "core/util/math_cpuonly.h" -#include "core/util/gemmlowp_common_wrapper.h" +#include "core/util/qmath.h" +#include "core/providers/common.h" namespace onnxruntime { @@ -25,30 +21,21 @@ ONNX_OPERATOR_KERNEL_EX( ConvInteger); Status ConvInteger::Compute(OpKernelContext* context) const { + size_t num_inputs = OpKernel::Node().InputDefs().size(); const auto* X = context->Input(0); const auto* W = context->Input(1); - int32_t input_offset = 0; - int32_t filter_offset = 0; + uint8_t input_offset = 0; + uint8_t filter_offset = 0; if (num_inputs >= 3) { const auto* X_Zero_Point = context->Input(2); - if (X_Zero_Point->Shape().NumDimensions() == 0 || - (X_Zero_Point->Shape().NumDimensions() == 1 && X_Zero_Point->Shape().GetDims().size() == 1)) { - input_offset = static_cast(*(X_Zero_Point->Data())); - } else { - //TODO: Add support for per-channel quantization. - return Status(common::ONNXRUNTIME, common::FAIL, "Non per-tensor quantization is not supported now."); - } + ORT_ENFORCE(IsScalarOr1ElementVector(X_Zero_Point), "Must be a scalar or 1D tensor of size 1."); + input_offset = *(X_Zero_Point->Data()); } if (num_inputs >= 4) { const auto* W_Zero_Point = context->Input(3); - if (W_Zero_Point->Shape().NumDimensions() == 0 || - (W_Zero_Point->Shape().NumDimensions() == 1 && W_Zero_Point->Shape().GetDims().size() == 1)) { - filter_offset = static_cast(*(W_Zero_Point->Data())); - } else { - //TODO: Add support for per-channel quantization.
- return Status(common::ONNXRUNTIME, common::FAIL, "Non per-tensor quantization is not supported now."); - } + ORT_ENFORCE(IsScalarOr1ElementVector(W_Zero_Point), "Non per-tensor quantization is not supported now."); + filter_offset = *(W_Zero_Point->Data()); } const int64_t N = X->Shape()[0]; @@ -118,27 +105,21 @@ Status ConvInteger::Compute(OpKernelContext* context) const { static_cast(kernel_shape.size()), col_buffer_data, &CPUMathUtil::Instance(), - false, - input_offset); - - const uint8_t* filter_data_as_uint8 = W->template Data() + group_id * W_offset; - static const gemmlowp::MapOrder ResultOrder = gemmlowp::MapOrder::RowMajor; - static const gemmlowp::MapOrder LhsOrder = gemmlowp::MapOrder::RowMajor; - static const gemmlowp::MapOrder RhsOrder = gemmlowp::MapOrder::RowMajor; - gemmlowp::MatrixMap lhs( - filter_data_as_uint8, static_cast(M / group_), static_cast(kernel_dim)); - gemmlowp::MatrixMap rhs( - col_buffer_data, static_cast(kernel_dim), static_cast(output_image_size)); - gemmlowp::MatrixMap result( - Ydata + group_id * Y_offset, static_cast(M / group_), static_cast(output_image_size)); - const std::tuple<> empty_pipeline = {}; - - gemmlowp::GemmContext gemm_context; - // TODO: worker thread pool needs to be handled. - gemmlowp::GemmWithOutputPipeline( - &gemm_context, lhs, rhs, &result, -filter_offset, -input_offset, - empty_pipeline); + false, + input_offset); + + QGemmu8u8_s32(static_cast(M / group_), + static_cast(output_image_size), + static_cast(kernel_dim), + W->template Data() + group_id * W_offset, + static_cast(kernel_dim), + filter_offset, + col_buffer_data, + static_cast(output_image_size), + input_offset, + Ydata + group_id * Y_offset, + static_cast(output_image_size), + nullptr); } Xdata += X_offset * group_; diff --git a/onnxruntime/core/providers/cpu/nn/conv_transpose.cc b/onnxruntime/core/providers/cpu/nn/conv_transpose.cc index 14f13ccd20198..9fd9cd1502147 100644 --- a/onnxruntime/core/providers/cpu/nn/conv_transpose.cc +++ b/onnxruntime/core/providers/cpu/nn/conv_transpose.cc @@ -16,6 +16,8 @@ /* Modifications Copyright (c) Microsoft. */ #include "core/providers/cpu/nn/conv_transpose.h" +#include "core/framework/op_kernel_context_internal.h" + #include "core/util/math.h" #include "core/util/math_cpuonly.h" @@ -228,6 +230,9 @@ Status ConvTranspose::Compute(OpKernelContext* context) const { template Status ConvTranspose::DoConvTranspose(OpKernelContext* context, bool dynamic_padding) const { + auto ctx_internal = static_cast(context); + concurrency::ThreadPool* tp = ctx_internal->GetOperatorThreadPool(); + size_t num_inputs = OpKernel::Node().InputDefs().size(); Prepare p; bool has_bias = dynamic_padding ? 
num_inputs == 4 : num_inputs == 3; @@ -254,7 +259,7 @@ Status ConvTranspose::DoConvTranspose(OpKernelContext* context, bool dynamic_ for (auto image_id = 0; image_id < p.N; ++image_id) { for (int group_id = 0; group_id < group_; ++group_id) { // Weight term - math::Gemm( + math::Gemm( CblasTrans, CblasNoTrans, kernel_dim, @@ -265,7 +270,7 @@ Status ConvTranspose::DoConvTranspose(OpKernelContext* context, bool dynamic_ Xdata + group_id * X_offset, 0, col_buffer_data, - &CPUMathUtil::Instance()); + tp); // Col2im math::Col2im( diff --git a/onnxruntime/core/providers/cpu/nn/pool.cc b/onnxruntime/core/providers/cpu/nn/pool.cc index 367a9256a0c16..47bc8fc856bb3 100644 --- a/onnxruntime/core/providers/cpu/nn/pool.cc +++ b/onnxruntime/core/providers/cpu/nn/pool.cc @@ -190,7 +190,7 @@ Status PoolBase::Compute(OpKernelContext* context, MLAS_POOLING_KIND kind) const // Get access to the internal threadpool // Temporarily derive concurrency parameters without access to session state auto ctx_internal = static_cast(context); - auto thread_pool = ctx_internal->GetOperatorThreadPool(); + concurrency::ThreadPool* thread_pool = ctx_internal->GetOperatorThreadPool(); MlasPool(kind, pooling_dims, diff --git a/onnxruntime/core/providers/cpu/nn/pool_base.h b/onnxruntime/core/providers/cpu/nn/pool_base.h index 43f81982dd3a9..606ac909f08f1 100644 --- a/onnxruntime/core/providers/cpu/nn/pool_base.h +++ b/onnxruntime/core/providers/cpu/nn/pool_base.h @@ -99,10 +99,13 @@ class LpPool { }; class PoolBase { + private: + static bool IsGlobalPooling(const std::string& op_name) { + return op_name == "GlobalAveragePool" || op_name == "GlobalMaxPool" || op_name == "GlobalLpPool"; + } + protected: - PoolBase(const OpKernelInfo& info) { - op_name_ = info.GetKernelDef().OpName(); - global_pooling_ = (op_name_ == "GlobalAveragePool" || op_name_ == "GlobalMaxPool" || op_name_ == "GlobalLpPool"); + PoolBase(const OpKernelInfo& info) : op_name_(info.GetKernelDef().OpName()), global_pooling_(IsGlobalPooling(op_name_)) { int end; info.GetKernelDef().SinceVersion(&start_version_, &end); @@ -256,8 +259,8 @@ class PoolBase { Status Compute(OpKernelContext* context, MLAS_POOLING_KIND kind) const; protected: - std::string op_name_; - bool global_pooling_{}; + const std::string op_name_; + const bool global_pooling_; bool count_include_pad_{}; int64_t storage_order_{0}; // MaxPool_8 only. 0 is row major, and 1 is column major. Default is 0. int64_t ceil_mode_{0}; // Introduced in MaxPool_10 diff --git a/onnxruntime/core/providers/cpu/nn/qlinearconv.cc b/onnxruntime/core/providers/cpu/nn/qlinearconv.cc index 1cf064f7ea9e1..78a53679325e8 100644 --- a/onnxruntime/core/providers/cpu/nn/qlinearconv.cc +++ b/onnxruntime/core/providers/cpu/nn/qlinearconv.cc @@ -1,14 +1,10 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. 
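// The QLinearConv hunks that follow (like the QLinearMatMul and MatMulInteger changes above)
// feed zero-point-adjusted uint8 operands into a shared quantized GEMM and then requantize the
// int32 accumulator using the integer_multiplier / right_shift pair produced by
// QuantizeMultiplier. The function below is a plain floating-point reference of that end-to-end
// computation for a single matrix multiply; ReferenceQLinearMatMul is an illustrative name, the
// production path performs the requantization in fixed point, and QLinearConv additionally adds
// an int32 bias to the accumulator before requantizing.
#include <algorithm>
#include <cmath>
#include <cstdint>

void ReferenceQLinearMatMul(const uint8_t* A, const uint8_t* B, uint8_t* Y,
                            int M, int N, int K,
                            uint8_t a_zero, uint8_t b_zero, uint8_t y_zero,
                            float a_scale, float b_scale, float y_scale) {
  const float real_multiplier = (a_scale * b_scale) / y_scale;
  for (int m = 0; m < M; ++m) {
    for (int n = 0; n < N; ++n) {
      // Accumulate the zero-point-adjusted products in int32, as the quantized GEMM does.
      int32_t acc = 0;
      for (int k = 0; k < K; ++k) {
        acc += (static_cast<int32_t>(A[m * K + k]) - a_zero) *
               (static_cast<int32_t>(B[k * N + n]) - b_zero);
      }
      // Requantize: scale back into the output's quantized domain, add its zero point,
      // and saturate to the uint8 range.
      const int32_t requantized = static_cast<int32_t>(std::lround(acc * real_multiplier)) + y_zero;
      Y[m * N + n] = static_cast<uint8_t>(std::min(255, std::max(0, requantized)));
    }
  }
}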
-#ifdef _MSC_VER -#pragma warning(disable : 4244) -#pragma warning(disable : 4267) -#endif - #include "core/providers/cpu/nn/qlinearconv.h" #include "core/util/math.h" #include "core/util/math_cpuonly.h" +#include "core/providers/common.h" namespace onnxruntime { ONNX_OPERATOR_KERNEL_EX( @@ -19,32 +15,40 @@ ONNX_OPERATOR_KERNEL_EX( KernelDefBuilder() .TypeConstraint("T1", DataTypeImpl::GetTensorType()) .TypeConstraint("T2", DataTypeImpl::GetTensorType()) - .TypeConstraint("T3", DataTypeImpl::GetTensorType()), + .TypeConstraint("T3", DataTypeImpl::GetTensorType()) + .TypeConstraint("T4", DataTypeImpl::GetTensorType()), QLinearConv); Status QLinearConv::Compute(OpKernelContext* context) const { const auto* X = context->Input(0); const auto* W = context->Input(3); - // validate scale and zero points - auto input_scale = context->Input(1); + // validate offsets auto input_offset = context->Input(2); - ScaleAndZeropointPairValidationHelper(input_scale, input_offset); - auto filter_scale = context->Input(4); auto filter_offset = context->Input(5); - ScaleAndZeropointPairValidationHelper(filter_scale, filter_offset); - auto result_scale = context->Input(6); auto result_offset = context->Input(7); - ScaleAndZeropointPairValidationHelper(result_scale, result_offset); + ORT_ENFORCE(IsScalarOr1ElementVector(input_offset), + "QLinearConv : input zero point must be a scalar or 1D tensor of size 1"); + ORT_ENFORCE(IsScalarOr1ElementVector(filter_offset), + "QLinearConv : filter zero point must be a scalar or 1D tensor of size 1"); + ORT_ENFORCE(IsScalarOr1ElementVector(result_offset), + "QLinearConv : result zero point must be a scalar or 1D tensor of size 1"); + + // validate scale + auto input_scale = context->Input(1); + auto filter_scale = context->Input(4); + auto result_scale = context->Input(6); + ORT_ENFORCE(IsScalarOr1ElementVector(input_scale), + "QLinearConv : input scale must be a scalar or 1D tensor of size 1"); + ORT_ENFORCE(IsScalarOr1ElementVector(filter_scale), + "QLinearConv : filter scale must be a scalar or 1D tensor of size 1"); + ORT_ENFORCE(IsScalarOr1ElementVector(result_scale), + "QLinearConv : result scale must be a scalar or 1D tensor of size 1"); auto input_scale_data = *(input_scale->template Data()); auto filter_scale_data = *(filter_scale->template Data()); auto result_scale_data = *(result_scale->template Data()); - auto input_offset_data = *(input_offset->template Data()); - auto filter_offset_data = *(filter_offset->template Data()); - auto result_offset_data = *(result_offset->template Data()); - const float real_multiplier = (input_scale_data * filter_scale_data) / result_scale_data; int32_t integer_multiplier; int right_shift; @@ -54,7 +58,7 @@ Status QLinearConv::Compute(OpKernelContext* context) const { const Tensor* bias = nullptr; if (num_inputs == 9) { bias = context->Input(8); - } + } const int64_t N = X->Shape()[0]; const int64_t C = X->Shape()[1]; @@ -95,7 +99,7 @@ Status QLinearConv::Compute(OpKernelContext* context) const { const int64_t kernel_size = TensorShape(kernel_shape).Size(); const int64_t X_offset = C / group_ * input_image_size; const int64_t Y_offset = Y->Shape().Size() / Y->Shape()[0] / group_; - const int64_t W_offset = W->Shape().Size() / group_; + const int64_t W_offset = W->Shape().Size() / group_; const int64_t kernel_dim = C / group_ * kernel_size; const int64_t col_buffer_size = kernel_dim * output_image_size; const int bias_offset = static_cast(M / group_); @@ -124,35 +128,21 @@ Status QLinearConv::Compute(OpKernelContext* context) 
const { static_cast(kernel_shape.size()), col_buffer_data, &CPUMathUtil::Instance(), - false, - input_offset_data); - - const uint8_t* filter_data_as_uint8 = W->template Data() + group_id * W_offset; - static const gemmlowp::MapOrder MatOrder = gemmlowp::MapOrder::RowMajor; - gemmlowp::MatrixMap lhs( - filter_data_as_uint8, static_cast(M / group_), static_cast(kernel_dim)); - gemmlowp::MatrixMap rhs( - col_buffer_data, static_cast(kernel_dim), static_cast(output_image_size)); - gemmlowp::MatrixMap result( - Ydata + group_id * Y_offset, static_cast(M / group_), static_cast(output_image_size)); - - // TODO: worker thread pool needs to be handled. - gemmlowp::GemmContext gemm_context; - if (bias == nullptr) { - auto output_pipeline = MakeOutputPipelineWithOutBias(result_offset_data, - integer_multiplier, right_shift); - gemmlowp::GemmWithOutputPipeline( - &gemm_context, lhs, rhs, &result, -filter_offset_data, -input_offset_data, - output_pipeline); - } else { - auto output_pipeline = MakeOutputPipelineWithBias(bias->template Data() + group_id * bias_offset, - static_cast(M / group_), result_offset_data, integer_multiplier, right_shift); - gemmlowp::GemmWithOutputPipeline( - &gemm_context, lhs, rhs, &result, -filter_offset_data, -input_offset_data, - output_pipeline); - } + false, + *input_offset->template Data()); + + GemmlowpMultiplyu8u8_u8(W->template Data() + group_id * W_offset, + col_buffer_data, + Ydata + group_id * Y_offset, + *filter_offset->template Data(), + *input_offset->template Data(), + *result_offset->template Data(), + static_cast(M / group_), + static_cast(output_image_size), + static_cast(kernel_dim), + integer_multiplier, + right_shift, + bias == nullptr ? nullptr : bias->template Data() + group_id * bias_offset); } Xdata += X_offset * group_; @@ -161,28 +151,4 @@ Status QLinearConv::Compute(OpKernelContext* context) const { return Status::OK(); } - -void QLinearConv::QuantizeMultiplier(float fp_multiplier, std::int32_t* integer_multiplier, int* right_shift) const { - auto* fp_as_bits = reinterpret_cast(&fp_multiplier); - auto current_exponent = (*fp_as_bits >> 23); - // bring multiplier in [.5,1) range and calculate the shift - auto bumped_multiplier_as_bits = - (*fp_as_bits & UINT32_C(0x007fffff)) | UINT32_C(0x3f000000); - auto* bumped_multiplier = reinterpret_cast(&bumped_multiplier_as_bits); - auto shift = 126 - current_exponent; - // convert to fixed point number - auto int_multiplier = static_cast(std::round(*bumped_multiplier * (1ll << 31))); - - *integer_multiplier = static_cast(int_multiplier); - *right_shift = shift; -} - -void QLinearConv::ScaleAndZeropointPairValidationHelper(const Tensor* scale, const Tensor* zeropoint) const { - ORT_ENFORCE(scale->Shape().NumDimensions() == 0 || - (scale->Shape().NumDimensions() == 1 && scale->Shape().GetDims().size() == 1), - "scale must be a scalar"); - ORT_ENFORCE(zeropoint->Shape().NumDimensions() == 0 || - (zeropoint->Shape().NumDimensions() == 1 && zeropoint->Shape().GetDims().size() == 1), - "zeropoint must be a scalar"); -} } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cpu/nn/qlinearconv.h b/onnxruntime/core/providers/cpu/nn/qlinearconv.h index c5e7919371bc8..9179da587c1f4 100644 --- a/onnxruntime/core/providers/cpu/nn/qlinearconv.h +++ b/onnxruntime/core/providers/cpu/nn/qlinearconv.h @@ -4,7 +4,7 @@ #pragma once #include "core/providers/cpu/nn/conv_base.h" -#include "core/util/gemmlowp_common_wrapper.h" +#include "core/util/gemmlowp_common.h" namespace onnxruntime { class QLinearConv : public 
OpKernel, public ConvBase { @@ -12,44 +12,6 @@ class QLinearConv : public OpKernel, public ConvBase { explicit QLinearConv(const OpKernelInfo& info) : OpKernel(info), ConvBase(info) { } - Status Compute(OpKernelContext* context) const override; - - void QuantizeMultiplier(float fp_multiplier, std::int32_t* integer_multiplier, int* right_shift) const; - - void ScaleAndZeropointPairValidationHelper(const Tensor* scale, const Tensor* zeropoint) const; + Status Compute(OpKernelContext* context) const override; }; - -typedef gemmlowp::VectorMap ColVectorMap; - -inline std::tuple, - gemmlowp::OutputStageQuantizeDownInt32ByFixedPoint, - gemmlowp::OutputStageSaturatingCastToUint8> -MakeOutputPipelineWithBias(const int32_t* bias, - int rows, - std::int32_t result_offset, - std::int32_t result_mult_int, - std::int32_t result_shift) { - ColVectorMap bias_vector(bias, rows); - gemmlowp::OutputStageBiasAddition bias_addition_stage; - bias_addition_stage.bias_vector = bias_vector; - gemmlowp::OutputStageQuantizeDownInt32ByFixedPoint quantize_down_stage; - quantize_down_stage.result_offset_after_shift = result_offset; - quantize_down_stage.result_fixedpoint_multiplier = result_mult_int; - quantize_down_stage.result_shift = result_shift; - gemmlowp::OutputStageSaturatingCastToUint8 saturating_cast_stage; - return std::make_tuple(bias_addition_stage, quantize_down_stage, saturating_cast_stage); -} - -inline std::tuple -MakeOutputPipelineWithOutBias(std::int32_t result_offset, - std::int32_t result_mult_int, - std::int32_t result_shift) { - gemmlowp::OutputStageQuantizeDownInt32ByFixedPoint quantize_down_stage; - quantize_down_stage.result_offset_after_shift = result_offset; - quantize_down_stage.result_fixedpoint_multiplier = result_mult_int; - quantize_down_stage.result_shift = result_shift; - gemmlowp::OutputStageSaturatingCastToUint8 saturating_cast_stage; - return std::make_tuple(quantize_down_stage, saturating_cast_stage); -} } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cpu/object_detection/non_max_suppression.cc b/onnxruntime/core/providers/cpu/object_detection/non_max_suppression.cc index 66084547810ad..c1a376026ab14 100644 --- a/onnxruntime/core/providers/cpu/object_detection/non_max_suppression.cc +++ b/onnxruntime/core/providers/cpu/object_detection/non_max_suppression.cc @@ -141,7 +141,7 @@ Status NonMaxSuppression::Compute(OpKernelContext* ctx) const { for (int64_t batch_index = 0; batch_index < pc.num_batches_; ++batch_index) { for (int64_t class_index = 0; class_index < pc.num_classes_; ++class_index) { int64_t box_score_offset = (batch_index * pc.num_classes_ + class_index) * pc.num_boxes_; - int64_t box_offset = batch_index * pc.num_classes_ * pc.num_boxes_ * 4; + int64_t box_offset = batch_index * pc.num_boxes_ * 4; // Filter by score_threshold_ std::priority_queue> sorted_scores_with_index; const auto* class_scores = scores_data + box_score_offset; @@ -158,7 +158,7 @@ Status NonMaxSuppression::Compute(OpKernelContext* ctx) const { } ScoreIndexPair next_top_score; - std::vector selected_indicies_inside_class; + std::vector selected_indices_inside_class; // Get the next box with top score, filter by iou_threshold while (!sorted_scores_with_index.empty()) { next_top_score = sorted_scores_with_index.top(); @@ -166,7 +166,7 @@ Status NonMaxSuppression::Compute(OpKernelContext* ctx) const { bool selected = true; // Check with existing selected boxes for this class, suppress if exceed the IOU (Intersection Over Union) threshold - for (int64_t selected_index : 
selected_indicies_inside_class) { + for (int64_t selected_index : selected_indices_inside_class) { if (SuppressByIOU(boxes_data + box_offset, selected_index, next_top_score.index_, center_point_box, iou_threshold)) { selected = false; @@ -176,10 +176,10 @@ Status NonMaxSuppression::Compute(OpKernelContext* ctx) const { if (selected) { if (max_output_boxes_per_class > 0 && - static_cast(selected_indicies_inside_class.size()) >= max_output_boxes_per_class) { + static_cast(selected_indices_inside_class.size()) >= max_output_boxes_per_class) { break; } - selected_indicies_inside_class.push_back(next_top_score.index_); + selected_indices_inside_class.push_back(next_top_score.index_); selected_indices.emplace_back(batch_index, class_index, next_top_score.index_); } } //while diff --git a/onnxruntime/core/providers/cpu/object_detection/roialign.cc b/onnxruntime/core/providers/cpu/object_detection/roialign.cc index 9453039aa8753..4d27e957e9f44 100644 --- a/onnxruntime/core/providers/cpu/object_detection/roialign.cc +++ b/onnxruntime/core/providers/cpu/object_detection/roialign.cc @@ -268,7 +268,7 @@ void RoiAlignForward( } // for ph } // for c }; // for n - const_cast(ttp)->ParallelFor(static_cast(n_rois), work_object); + if (ttp != nullptr) const_cast(ttp)->ParallelFor(static_cast(n_rois), work_object); } } // namespace diff --git a/onnxruntime/core/providers/cpu/reduction/reduction_ops.cc b/onnxruntime/core/providers/cpu/reduction/reduction_ops.cc index 8c9143a238868..b418012574c7e 100644 --- a/onnxruntime/core/providers/cpu/reduction/reduction_ops.cc +++ b/onnxruntime/core/providers/cpu/reduction/reduction_ops.cc @@ -30,15 +30,25 @@ namespace onnxruntime { KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), \ x); +#define REGISTER_UNARY_ELEMENTWISE_KERNEL_INT64_ONLY(x, sinceVersion) \ + ONNX_CPU_OPERATOR_TYPED_KERNEL( \ + x, \ + sinceVersion, \ + int64_t, \ + KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), \ + x); + REGISTER_UNARY_ELEMENTWISE_KERNEL(ReduceL1, 1); REGISTER_UNARY_ELEMENTWISE_KERNEL(ReduceL2, 1); REGISTER_UNARY_ELEMENTWISE_KERNEL(ReduceLogSum, 1); REGISTER_UNARY_ELEMENTWISE_KERNEL(ReduceLogSumExp, 1); REGISTER_UNARY_ELEMENTWISE_KERNEL(ReduceMax, 1); +REGISTER_UNARY_ELEMENTWISE_KERNEL_INT64_ONLY(ReduceMax, 1); REGISTER_UNARY_ELEMENTWISE_KERNEL(ReduceMean, 1); REGISTER_UNARY_ELEMENTWISE_KERNEL(ReduceMin, 1); REGISTER_UNARY_ELEMENTWISE_KERNEL(ReduceProd, 1); REGISTER_UNARY_ELEMENTWISE_KERNEL(ReduceSum, 1); +REGISTER_UNARY_ELEMENTWISE_KERNEL_INT64_ONLY(ReduceSum, 1); REGISTER_UNARY_ELEMENTWISE_KERNEL_DOUBLE_ONLY(ReduceSum, 1); REGISTER_UNARY_ELEMENTWISE_KERNEL(ReduceSumSquare, 1); REGISTER_UNARY_ELEMENTWISE_KERNEL_DOUBLE_ONLY(ReduceSumSquare, 1); diff --git a/onnxruntime/core/providers/cpu/rnn/deep_cpu_gru.cc b/onnxruntime/core/providers/cpu/rnn/deep_cpu_gru.cc index c5be268f59e2d..0dd13269bfacd 100644 --- a/onnxruntime/core/providers/cpu/rnn/deep_cpu_gru.cc +++ b/onnxruntime/core/providers/cpu/rnn/deep_cpu_gru.cc @@ -1,5 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. 
+#include "core/platform/threadpool.h" +#include "core/framework/op_kernel_context_internal.h" // there's no way to use a raw pointer as the copy destination with std::copy_n // (which gsl::copy uses with span::data() which returns a raw pointer) with the 14.11 toolset @@ -167,7 +169,8 @@ class UniDirectionalGru { UniDirectionalGru(AllocatorPtr allocator, int seq_length, int batch_size, int input_size, int hidden_size, bool linear_before_reset, Direction direction, const gsl::span& bias, const gsl::span& initial_hidden_state, const ActivationFuncs::Entry& activation_func_f, - const ActivationFuncs::Entry& activation_func_g, float clip); + const ActivationFuncs::Entry& activation_func_g, float clip, + onnxruntime::concurrency::ThreadPool* ttp); void Compute(const gsl::span& inputs, const gsl::span& sequence_lengths, int num_directions, const gsl::span& input_weights, const gsl::span& recurrent_weights, @@ -233,6 +236,8 @@ class UniDirectionalGru { deepcpu::GruOutputGateFuncPtr output_gate_{}; void AllocateBuffers(); + + onnxruntime::concurrency::ThreadPool* ttp_; }; } // namespace detail @@ -263,6 +268,9 @@ Status DeepCpuGruOp::Compute(OpKernelContext* context) const { template Status DeepCpuGruOp::ComputeImpl(OpKernelContext& context) const { + auto ctx_internal = static_cast(&context); + concurrency::ThreadPool* thread_pool = ctx_internal->GetOperatorThreadPool(); + const Tensor& X = *context.Input(0); // inputs. [seq_length, batch_size, input_size] const Tensor& W = *context.Input(1); // weights. [num_directions, 3*hidden_size, input_size] const Tensor& R = *context.Input(2); // recurrence weights. [num_directions, 3*hidden_size, hidden_size] @@ -367,7 +375,7 @@ Status DeepCpuGruOp::ComputeImpl(OpKernelContext& context) const { linear_before_reset_, Direction::kForward, bias_1, initial_hidden_1, activation_funcs_.Entries()[0], activation_funcs_.Entries()[1], - clip_); + clip_, thread_pool); fw.Compute(input, sequence_lens_span, num_directions_, input_weights_1, recurrent_weights_1, output_1, hidden_output_1); @@ -375,7 +383,7 @@ Status DeepCpuGruOp::ComputeImpl(OpKernelContext& context) const { linear_before_reset_, Direction::kReverse, bias_2, initial_hidden_2, activation_funcs_.Entries()[2], activation_funcs_.Entries()[3], - clip_); + clip_, thread_pool); bw.Compute(input, sequence_lens_span, num_directions_, input_weights_2, recurrent_weights_2, output_2, hidden_output_2); } else { @@ -383,7 +391,7 @@ Status DeepCpuGruOp::ComputeImpl(OpKernelContext& context) const { linear_before_reset_, direction_, bias_1, initial_hidden_1, activation_funcs_.Entries()[0], activation_funcs_.Entries()[1], - clip_); + clip_, thread_pool); gru_p.Compute(input, sequence_lens_span, num_directions_, input_weights_1, recurrent_weights_1, output_1, hidden_output_1); } @@ -412,7 +420,7 @@ UniDirectionalGru::UniDirectionalGru(AllocatorPtr allocator, const gsl::span& initial_hidden_state, const ActivationFuncs::Entry& activation_func_f, const ActivationFuncs::Entry& activation_func_g, - const float clip) + const float clip, onnxruntime::concurrency::ThreadPool* ttp) : allocator_(allocator), seq_length_(seq_length), batch_size_(batch_size), @@ -421,7 +429,8 @@ UniDirectionalGru::UniDirectionalGru(AllocatorPtr allocator, linear_before_reset_(linear_before_reset), clip_(clip), direction_(direction), - use_bias_(!bias.empty()) { + use_bias_(!bias.empty()), + ttp_(ttp) { clip_with_bias_ptr_ = use_bias_ ? 
deepcpu::clip_add_bias : deepcpu::clip_ignore_bias; // setup activation function pointers and alpha/beta values to use with them @@ -540,7 +549,7 @@ void UniDirectionalGru::Compute(const gsl::span& inputs_arg, input_weights.cbegin(), input_weights.cend(), input_size_, beta, outputZRH_.begin(), outputZRH_.end(), - hidden_size_x3); + hidden_size_x3, ttp_); DumpMatrix("inputs with weights applied", outputZRH_.data(), seq_length_ * batch_size_ * 3, hidden_size_); @@ -606,7 +615,7 @@ void UniDirectionalGru::Compute(const gsl::span& inputs_arg, recurrent_weightsZR.cbegin(), recurrent_weightsZR.cend(), hidden_size_, beta, outputZRH_.begin() + out_added_offset, outputZRH_.end(), - hidden_size_x3); + hidden_size_x3, ttp_); DumpMatrix("Ht-1 * R[zr] + Xt*(W[zr]^T)" + seqno_str, outputZRH_.data() + out_added_offset, batch_size_, hidden_size_x2, 0, hidden_size_x3); @@ -622,7 +631,7 @@ void UniDirectionalGru::Compute(const gsl::span& inputs_arg, recurrent_weightsH.cbegin(), recurrent_weightsH.cend(), // Rh^T hidden_size_, beta, linear_output_.begin(), linear_output_.end(), // pre: Rbh, post:output - hidden_size_); + hidden_size_, ttp_); DumpMatrix("Ht-1 * (Rh^T) + Rbh " + seqno_str, linear_output_.data(), batch_size_, hidden_size_); } @@ -693,7 +702,7 @@ void UniDirectionalGru::Compute(const gsl::span& inputs_arg, recurrent_weightsH.cbegin(), recurrent_weightsH.cend(), // Rh^T hidden_size_, beta, out_H, outputZRH_.end(), - hidden_size_x3); + hidden_size_x3, ttp_); } DumpMatrix("Xt*(Wh^T) + (" + label + ")" + seqno_str, outputZRH_.data() + out_added_offset, diff --git a/onnxruntime/core/providers/cpu/rnn/deep_cpu_lstm.cc b/onnxruntime/core/providers/cpu/rnn/deep_cpu_lstm.cc index 8f4e8236981f8..682dabd9262ca 100644 --- a/onnxruntime/core/providers/cpu/rnn/deep_cpu_lstm.cc +++ b/onnxruntime/core/providers/cpu/rnn/deep_cpu_lstm.cc @@ -9,6 +9,9 @@ #pragma warning(disable : 4996) #endif +#include "core/platform/threadpool.h" +#include "core/framework/op_kernel_context_internal.h" + #include "core/providers/cpu/rnn/deep_cpu_lstm.h" #include "core/common/common.h" @@ -193,7 +196,8 @@ class UniDirectionalLstm { const gsl::span& initial_hidden_state, const gsl::span& initial_cell_state, const ActivationFuncs::Entry& activation_func_f, const ActivationFuncs::Entry& activation_func_g, const ActivationFuncs::Entry& activation_func_h, float clip, - onnxruntime::concurrency::ThreadPool& ttp); + concurrency::ThreadPool& lstm_tp_, + concurrency::ThreadPool* mlas_tp_); void Compute(const gsl::span& inputs, const gsl::span& sequence_lengths, int num_directions, const gsl::span& input_weights, const gsl::span& recurrent_weights, @@ -275,7 +279,8 @@ class UniDirectionalLstm { ActivationInfo activation_g_; ActivationInfo activation_h_; - onnxruntime::concurrency::ThreadPool& ttp_; + concurrency::ThreadPool& lstm_tp_; + concurrency::ThreadPool* mlas_tp_; }; } // namespace detail @@ -309,6 +314,9 @@ DeepCpuLstmOp::Compute(OpKernelContext* context) const { template Status DeepCpuLstmOp::ComputeImpl(OpKernelContext& context) const { + auto ctx_internal = static_cast(&context); + concurrency::ThreadPool* mlas_thread_pool = ctx_internal->GetOperatorThreadPool(); + auto& logger = context.Logger(); const Tensor& X = *context.Input(0); // inputs. 
[seq_length, batch_size, input_size] @@ -452,7 +460,7 @@ Status DeepCpuLstmOp::ComputeImpl(OpKernelContext& context) const { activation_funcs_.Entries()[0], activation_funcs_.Entries()[1], activation_funcs_.Entries()[2], - clip_, ttp_); + clip_, lstm_tp_, mlas_thread_pool); detail::UniDirectionalLstm bw(alloc, logger, seq_length, batch_size, input_size, hidden_size_, Direction::kReverse, input_forget_, @@ -460,7 +468,7 @@ Status DeepCpuLstmOp::ComputeImpl(OpKernelContext& context) const { activation_funcs_.Entries()[3], activation_funcs_.Entries()[4], activation_funcs_.Entries()[5], - clip_, ttp_); + clip_, lstm_tp_, mlas_thread_pool); fw.Compute(input, sequence_lens_span, num_directions_, input_weights_1, recurrent_weights_1, output_1, hidden_output_1, last_cell_1); @@ -473,7 +481,7 @@ Status DeepCpuLstmOp::ComputeImpl(OpKernelContext& context) const { activation_funcs_.Entries()[0], activation_funcs_.Entries()[1], activation_funcs_.Entries()[2], - clip_, ttp_); + clip_, lstm_tp_, mlas_thread_pool); fw.Compute(input, sequence_lens_span, num_directions_, input_weights_1, recurrent_weights_1, output_1, hidden_output_1, last_cell_1); @@ -546,7 +554,8 @@ UniDirectionalLstm::UniDirectionalLstm(AllocatorPtr allocator, const ActivationFuncs::Entry& activation_func_g, const ActivationFuncs::Entry& activation_func_h, const float clip, - onnxruntime::concurrency::ThreadPool& ttp) + concurrency::ThreadPool& lstm_tp, + concurrency::ThreadPool* mlas_tp) : allocator_(allocator), logger_(logger), seq_length_(seq_length), @@ -558,7 +567,8 @@ UniDirectionalLstm::UniDirectionalLstm(AllocatorPtr allocator, clip_(clip), use_bias_(!bias.empty()), use_peepholes_(!peephole_weights.empty()), - ttp_(ttp) { + lstm_tp_(lstm_tp), + mlas_tp_(mlas_tp) { activation_f_ = {deepcpu::ActivationFuncByName(activation_func_f.name), activation_func_f.alpha, activation_func_f.beta}; @@ -774,7 +784,7 @@ void UniDirectionalLstm::Compute(const gsl::span& inputs_arg, input_weights.cbegin(), input_weights.cend(), // W[iofc] input_size_, beta, output_iofc_.begin(), output_iofc_.end(), - hidden_size_x4); + hidden_size_x4, mlas_tp_); DumpMatrix("Xt*(W[iofc]^T)", output_iofc_.data(), total_rows, hidden_size_x4); @@ -823,7 +833,7 @@ void UniDirectionalLstm::Compute(const gsl::span& inputs_arg, recurrent_weights.cbegin(), recurrent_weights.cend(), // R[iofc] hidden_size_, beta, step_out_IOFC, output_iofc_.end(), // input contains Xt*(W[iofc]^T) - hidden_size_x4); + hidden_size_x4, mlas_tp_); DumpMatrix("Xt*(W[iofc]^T) + Ht-t*R[iofc]" + row_str, &*step_out_IOFC, local_fused_hidden_rows, hidden_size_x4); @@ -874,7 +884,7 @@ void UniDirectionalLstm::Compute(const gsl::span& inputs_arg, } }; - ExecuteLambdaInParallel("Processing batch", hidden_gemm_and_activations, batch_size_, fused_hidden_rows, ttp_, logger_); + ExecuteLambdaInParallel("Processing batch", hidden_gemm_and_activations, batch_size_, fused_hidden_rows, lstm_tp_, logger_); } else { span_T_iter c_prev = batched_internal_state_prev_one_step.begin(); @@ -901,7 +911,7 @@ void UniDirectionalLstm::Compute(const gsl::span& inputs_arg, recurrent_weights.cbegin(), recurrent_weights.cend(), // R[iofc] hidden_size_, beta, step_out_IOFC, output_iofc_.end(), // input contains Xt*(W[iofc]^T) - hidden_size_x4); + hidden_size_x4, mlas_tp_); span_T_iter batched_output; span_T_iter batched_output_end; diff --git a/onnxruntime/core/providers/cpu/rnn/deep_cpu_lstm.h b/onnxruntime/core/providers/cpu/rnn/deep_cpu_lstm.h index 606dfbf5b190c..faf32e3a77a2f 100644 --- 
a/onnxruntime/core/providers/cpu/rnn/deep_cpu_lstm.h +++ b/onnxruntime/core/providers/cpu/rnn/deep_cpu_lstm.h @@ -82,8 +82,8 @@ class DeepCpuLstmOp final : public OpKernel { // across them. mutable due to this. // The alternative would be to create a threadpool in each call to Compute but that would incur thread creation // cost on every call. - mutable onnxruntime::concurrency::ThreadPool ttp_{"DEEPCPU_LSTM", - static_cast(std::thread::hardware_concurrency())}; + mutable onnxruntime::concurrency::ThreadPool lstm_tp_{"DEEPCPU_LSTM", + static_cast(std::thread::hardware_concurrency())}; }; } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cpu/rnn/rnn.cc b/onnxruntime/core/providers/cpu/rnn/rnn.cc index 4030d65a94d45..d26b02f81ae68 100644 --- a/onnxruntime/core/providers/cpu/rnn/rnn.cc +++ b/onnxruntime/core/providers/cpu/rnn/rnn.cc @@ -1,5 +1,6 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +#include "core/framework/op_kernel_context_internal.h" #include "core/providers/cpu/rnn/rnn.h" #include "core/providers/cpu/rnn/rnn_activation_functors.h" @@ -99,6 +100,8 @@ using EigenMatrixMapRowMajor = Eigen::Map< template <> Status RNN::Compute(OpKernelContext* ctx) const { using namespace rnn::detail; + auto ctx_internal = static_cast(ctx); + concurrency::ThreadPool* tp = ctx_internal->GetOperatorThreadPool(); // inputs const Tensor& X = *ctx->Input(0); @@ -160,7 +163,7 @@ Status RNN::Compute(OpKernelContext* ctx) const { } // X * W[direction]^t + B - math::Gemm( + math::Gemm( CblasNoTrans, CblasTrans, static_cast(seq_length * batch_size), @@ -171,7 +174,7 @@ Status RNN::Compute(OpKernelContext* ctx) const { W.template Data() + direction * hidden_size_ * input_size, 1, x_matmul_w_buffer_data, - &CPUMathUtil::Instance()); + tp); for (int64_t t = 0; t < seq_length; t++) { int64_t time_step = isReverse ? 
(seq_length - t - 1) : t; @@ -181,8 +184,12 @@ Status RNN::Compute(OpKernelContext* ctx) const { const float* h_prev = nullptr; if (t == 0) { - if (initial_h != nullptr) - h_prev = initial_h->template Data(); + if (initial_h != nullptr) { + // the shape of initial_h is [num_directions, batch_size, hidden_size] + // so pick the offset (multiple of Y_frame_size == batch_size * hidden_size_) + // based on the direction + h_prev = initial_h->template Data() + (direction * Y_frame_size); + } } else { if (isReverse) h_prev = Y_buffer_data_current_frame + num_directions * Y_frame_size; @@ -192,7 +199,7 @@ Status RNN::Compute(OpKernelContext* ctx) const { if (h_prev != nullptr) { // H_t_1 * R[direction]^t - math::Gemm( + math::Gemm( CblasNoTrans, CblasTrans, static_cast(batch_size), @@ -203,7 +210,7 @@ Status RNN::Compute(OpKernelContext* ctx) const { R.template Data() + direction * hidden_size_ * hidden_size_, 0, Y_buffer_data_current_frame, - &CPUMathUtil::Instance()); + tp); } else { math::Set(batch_size * hidden_size_, 0, Y_buffer_data_current_frame, &CPUMathUtil::Instance()); } diff --git a/onnxruntime/core/providers/cpu/rnn/rnn_helpers.h b/onnxruntime/core/providers/cpu/rnn/rnn_helpers.h index 2e3e5f88d72ec..f1038e63a350e 100644 --- a/onnxruntime/core/providers/cpu/rnn/rnn_helpers.h +++ b/onnxruntime/core/providers/cpu/rnn/rnn_helpers.h @@ -159,7 +159,7 @@ void ComputeGemm(const int M, const float beta, TSpanCIter C, TSpanCIter C_end, - const int ldc) { + const int ldc, concurrency::ThreadPool* tp) { // validate all the inputs // need to use the lda/ldb/ldc strides which should be >= the columns for the span ORT_ENFORCE(lda >= K && ldb >= K && ldc >= N); @@ -167,12 +167,12 @@ void ComputeGemm(const int M, ORT_ENFORCE(B + (N * ldb - (ldb - K)) <= B_end); ORT_ENFORCE(C + (M * ldc - (ldc - N)) <= C_end); - ::onnxruntime::math::GemmEx( + ::onnxruntime::math::GemmEx( CblasNoTrans, CblasTrans, M, N, K, alpha, &*A, lda, &*B, ldb, beta, - &*C, ldc, &CPUMathUtil::Instance()); + &*C, ldc, tp); } // helper to convert a span to a raw pointer diff --git a/onnxruntime/core/providers/cpu/symbols.txt b/onnxruntime/core/providers/cpu/symbols.txt index fc7560f5b7696..1d2750e5d3d3e 100644 --- a/onnxruntime/core/providers/cpu/symbols.txt +++ b/onnxruntime/core/providers/cpu/symbols.txt @@ -12,7 +12,7 @@ OrtCompareAllocatorInfo OrtCreateAllocatorInfo OrtCreateCpuAllocatorInfo OrtCreateCustomOpDomain -OrtCreateDefaultAllocator +OrtGetAllocatorWithDefaultOptions OrtCreateEnv OrtCreateEnvWithCustomLogger OrtCreateRunOptions @@ -41,7 +41,6 @@ OrtGetErrorMessage OrtGetStringTensorContent OrtGetStringTensorDataLength OrtGetTensorElementType -OrtGetTensorMemSizeInBytesFromTensorProto OrtGetTensorMutableData OrtGetTensorShapeElementCount OrtGetTensorTypeAndShape @@ -52,7 +51,6 @@ OrtGetValueType OrtGetVersionString OrtIsTensor OrtGetOnnxTypeFromTypeInfo -OrtReleaseAllocator OrtReleaseAllocatorInfo OrtReleaseCustomOpDomain OrtReleaseEnv @@ -64,10 +62,10 @@ OrtReleaseTensorTypeAndShapeInfo OrtReleaseTypeInfo OrtReleaseValue OrtRun -OrtRunCallback OrtRunOptionsGetRunLogVerbosityLevel OrtRunOptionsGetRunTag OrtRunOptionsSetRunLogVerbosityLevel +OrtRunOptionsSetRunLogSeverityLevel OrtRunOptionsSetRunTag OrtRunOptionsEnableTerminate OrtRunOptionsDisableTerminate @@ -82,6 +80,7 @@ OrtSetDimensions OrtSetSessionGraphOptimizationLevel OrtSetSessionLogId OrtSetSessionLogVerbosityLevel +OrtSetSessionLogSeverityLevel +OrtSetOptimizedModelFilePath OrtSetSessionThreadPoolSize OrtSetTensorElementType -OrtTensorProtoToOrtValue diff 
--git a/onnxruntime/core/providers/cpu/tensor/cast_op.cc b/onnxruntime/core/providers/cpu/tensor/cast_op.cc index c326d25ef17a0..0f8da8eaff2a6 100644 --- a/onnxruntime/core/providers/cpu/tensor/cast_op.cc +++ b/onnxruntime/core/providers/cpu/tensor/cast_op.cc @@ -10,7 +10,7 @@ #include "Eigen/src/Core/arch/GPU/Half.h" #include "core/common/common.h" -#if defined(USE_MLAS) && defined(_M_AMD64) +#if defined(_M_AMD64) #include "core/mlas/inc/mlas.h" #endif @@ -40,7 +40,7 @@ inline void CastData(const Tensor* in, Tensor* out, const Tens auto out_data = out->template MutableData(); auto in_data = in->template Data(); auto shape_size = shape.Size(); -#if defined(USE_MLAS) && defined(_M_AMD64) +#if defined(_M_AMD64) MlasConvertHalfToFloatBuffer(&in_data[0].val, out_data, shape_size); #else auto in_vector = ConstEigenVectorMap(static_cast(static_cast(in_data)), shape_size); diff --git a/onnxruntime/core/providers/cpu/tensor/compress.cc b/onnxruntime/core/providers/cpu/tensor/compress.cc index e732121adbf02..b3f82bf9fdc2a 100644 --- a/onnxruntime/core/providers/cpu/tensor/compress.cc +++ b/onnxruntime/core/providers/cpu/tensor/compress.cc @@ -9,7 +9,8 @@ namespace onnxruntime { ONNX_CPU_OPERATOR_KERNEL( Compress, 9, - KernelDefBuilder().TypeConstraint("T", DataTypeImpl::AllTensorTypes()), + KernelDefBuilder().TypeConstraint("T", DataTypeImpl::AllTensorTypes()) + .TypeConstraint("T1", DataTypeImpl::GetTensorType()), Compress); Status Compress::Compute(OpKernelContext* ctx) const { diff --git a/onnxruntime/core/providers/cpu/tensor/concat.cc b/onnxruntime/core/providers/cpu/tensor/concat.cc index afca4d421efe8..0a26ea2a0dd42 100644 --- a/onnxruntime/core/providers/cpu/tensor/concat.cc +++ b/onnxruntime/core/providers/cpu/tensor/concat.cc @@ -34,16 +34,17 @@ Status ConcatBase::PrepareForCompute(OpKernelContext* ctx, int input_count, Prep auto& inputs_n = *tensor_pointer; const auto& inputs_n_dims = inputs_n.Shape().GetDims(); const size_t inputs_n_rank = inputs_n_dims.size(); - ORT_ENFORCE(inputs_n_rank == inputs_0_rank, "Ranks of input data are different, cannot concatenate them, " - "expected rank: ", std::to_string(inputs_0_rank), " got: ", std::to_string(inputs_n_rank)); + ORT_ENFORCE(inputs_n_rank == inputs_0_rank, + "Ranks of input data are different, cannot concatenate them. 
expected rank: ", + inputs_0_rank, " got: ", inputs_n_rank); // Ensure all the other (non-concat) axes match for (size_t axis_index = 0; axis_index < inputs_0_rank; ++axis_index) { num_elements *= inputs_n_dims[axis_index]; if (axis_index == p.axis) continue; ORT_RETURN_IF_NOT(inputs_n_dims[axis_index] == inputs_0_dims[axis_index], - "Non concat axis dimensions must match: Axis ", - axis_index, " has mismatched dimensions of ", inputs_n_dims[axis_index], + "Non concat axis dimensions must match: Axis ", + axis_index, " has mismatched dimensions of ", inputs_n_dims[axis_index], " and ", inputs_0_dims[axis_index]); } tensor_num_elements[index] = num_elements; @@ -58,7 +59,7 @@ Status ConcatBase::PrepareForCompute(OpKernelContext* ctx, int input_count, Prep // Calculate the shape of the output tensor std::vector dims(inputs_0_rank); - size_t num_elements = 1; // cache size of the first input along the way + size_t num_elements = 1; // cache size of the first input along the way for (size_t dimension_index = 0; dimension_index < inputs_0_rank; dimension_index++) { dims[dimension_index] = inputs_0_dims[dimension_index]; num_elements *= inputs_0_dims[dimension_index]; @@ -66,7 +67,7 @@ Status ConcatBase::PrepareForCompute(OpKernelContext* ctx, int input_count, Prep tensor_num_elements[0] = num_elements; dims[p.axis] = concat_axis_size; TensorShape output_shape(dims); - + auto& concat_result = *ctx->Output(0, output_shape); p.output_tensor = &concat_result; p.output_num_elements = output_shape.Size(); @@ -75,7 +76,7 @@ Status ConcatBase::PrepareForCompute(OpKernelContext* ctx, int input_count, Prep // there is no need to proceed further if (p.output_num_elements == 0) return Status::OK(); - + // The output_axis_pitch is the number of elements to add to move to the next split axis in the output p.output_axis_pitch = 1; for (size_t i = inputs_0_rank; i-- > p.axis;) p.output_axis_pitch *= dims[i]; @@ -110,7 +111,7 @@ Status Concat::Compute(OpKernelContext* ctx) const { auto is_string_type = ctx->Input(0)->DataType() == DataTypeImpl::GetType(); - int64_t output_offset = 0; + int64_t initial_output_offset = 0; // initial offset for each input auto element_bytes = p.output_tensor->DataType()->Size(); for (int input_index = 0; input_index < input_count; input_index++) { const auto& prep = p.inputs[input_index]; @@ -124,19 +125,29 @@ Status Concat::Compute(OpKernelContext* ctx) const { // Copy the data across. 
For every 'input_axis_pitch' values copied, we move over by the 'output_axis_pitch' uint8_t* output = static_cast(p.output_tensor->MutableDataRaw()); - for (size_t idxCopy = 0; idxCopy < input_size / input_axis_pitch; ++idxCopy) { + int64_t cur_out_offset = 0; + int64_t cur_in_offset = 0; + for (size_t idx_copy = 0, end = input_size / input_axis_pitch; idx_copy < end; ++idx_copy) { if (is_string_type) { - for (int idxItem = 0; idxItem < input_axis_pitch; ++idxItem) - reinterpret_cast(output)[output_offset + idxCopy * p.output_axis_pitch + idxItem] = - reinterpret_cast(input)[idxCopy * input_axis_pitch + idxItem]; - } else + size_t out = initial_output_offset + cur_out_offset; + for (int idx_item = 0; idx_item < input_axis_pitch; ++idx_item) { + reinterpret_cast(output)[out + idx_item] = + reinterpret_cast(input)[cur_in_offset + idx_item]; + } + } else { memcpy( - output + (output_offset + idxCopy * p.output_axis_pitch) * element_bytes, - input + idxCopy * input_axis_pitch * element_bytes, + output + (initial_output_offset + cur_out_offset) * element_bytes, + input + cur_in_offset * element_bytes, input_axis_pitch * element_bytes); + } + + cur_out_offset += p.output_axis_pitch; + cur_in_offset += input_axis_pitch; } - output_offset += input_axis_pitch; + + initial_output_offset += input_axis_pitch; } + return Status::OK(); } diff --git a/onnxruntime/core/providers/cpu/tensor/dynamicquantizelinear.cc b/onnxruntime/core/providers/cpu/tensor/dynamicquantizelinear.cc new file mode 100644 index 0000000000000..dafa3a322f5e8 --- /dev/null +++ b/onnxruntime/core/providers/cpu/tensor/dynamicquantizelinear.cc @@ -0,0 +1,75 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "dynamicquantizelinear.h" +#include "core/providers/common.h" +#include "core/util/math_cpuonly.h" +#include +#include + +namespace onnxruntime { + +ONNX_CPU_OPERATOR_TYPED_KERNEL( + DynamicQuantizeLinear, + 11, + uint8_t, + KernelDefBuilder() + .TypeConstraint("T2", DataTypeImpl::GetTensorType()), + DynamicQuantizeLinear); + + +static float RoundHalfToEven(float input) { + std::fesetround(FE_TONEAREST); + auto result = std::nearbyintf(input); + return result; +} + +// formula is Y = X / Scale + ZeroPoint +template +Status DynamicQuantizeLinear::Compute(OpKernelContext* ctx) const { + auto x_ptr = ctx->Input(0); + ORT_ENFORCE(x_ptr != nullptr); + auto& x = *x_ptr; + const auto* x_data = x.template Data(); + + auto& y = *ctx->Output(0, x.Shape()); + std::vector shape({}); + auto& y_scale = *ctx->Output(1, shape); + auto& y_zeropoint = *ctx->Output(2, shape); + + // find quantization range min and max + float qmax = std::numeric_limits::max(); + float qmin = std::numeric_limits::min(); + // Adjust the int8 range to -127 to 127 so that zero point can be 0 + if (qmin == -128) { + qmin = -127; + } + + // find input range min and max + auto min = ConstEigenVectorMap(x_data, x.Shape().Size()).minCoeff(); + min = std::min(min, qmin); + auto max = ConstEigenVectorMap(x_data, x.Shape().Size()).maxCoeff(); + max = std::max(max, qmin); + + // find scale and zero point + auto scale = (max - min) / (qmax - qmin); + auto* output_scale = y_scale.template MutableData(); + *output_scale = scale; + + const auto initial_zero_point = qmin - min / scale; + auto zero_point = static_cast(RoundHalfToEven(std::max(qmin, std::min(qmax, initial_zero_point)))); + auto* output_zp = y_zeropoint.template MutableData(); + *output_zp = zero_point; + + // quantize the data + auto* output = y.template 
MutableData(); + const auto num_of_elements = x.Shape().Size(); + + for (int i = 0; i < num_of_elements; ++i) { + output[i] = static_cast(clamp(RoundHalfToEven(static_cast(x_data[i] / scale)) + zero_point, qmin, qmax)); + } + + return Status::OK(); +} + +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/cpu/tensor/dynamicquantizelinear.h b/onnxruntime/core/providers/cpu/tensor/dynamicquantizelinear.h new file mode 100644 index 0000000000000..fa15cc9126cb6 --- /dev/null +++ b/onnxruntime/core/providers/cpu/tensor/dynamicquantizelinear.h @@ -0,0 +1,20 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/common/common.h" +#include "core/framework/op_kernel.h" + +namespace onnxruntime { + +template +class DynamicQuantizeLinear final : public OpKernel { + public: + DynamicQuantizeLinear(const OpKernelInfo& info) : OpKernel(info) { + } + + Status Compute(OpKernelContext* context) const override; + +}; +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/cpu/tensor/identity_op.cc b/onnxruntime/core/providers/cpu/tensor/identity_op.cc index b7fe35c73f039..f431d9de70185 100644 --- a/onnxruntime/core/providers/cpu/tensor/identity_op.cc +++ b/onnxruntime/core/providers/cpu/tensor/identity_op.cc @@ -10,7 +10,8 @@ ONNX_CPU_OPERATOR_VERSIONED_KERNEL( 7, 9, KernelDefBuilder().TypeConstraint("T", {DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType(), - DataTypeImpl::GetTensorType()}), + DataTypeImpl::GetTensorType()}) + .TypeConstraint("T1", DataTypeImpl::GetTensorType()), IdentityOp); ONNX_CPU_OPERATOR_KERNEL( diff --git a/onnxruntime/core/providers/cpu/tensor/nonzero_op.cc b/onnxruntime/core/providers/cpu/tensor/nonzero_op.cc index 7c725bba0f8b6..ef16693dc73c4 100644 --- a/onnxruntime/core/providers/cpu/tensor/nonzero_op.cc +++ b/onnxruntime/core/providers/cpu/tensor/nonzero_op.cc @@ -23,7 +23,7 @@ namespace onnxruntime { // start with a subset of types, enable more as needed... NONZERO_TYPED_KERNEL(bool) -//NONZERO_TYPED_KERNEL(uint8_t) +NONZERO_TYPED_KERNEL(uint8_t) //NONZERO_TYPED_KERNEL(uint16_t) //NONZERO_TYPED_KERNEL(uint32_t) //NONZERO_TYPED_KERNEL(uint64_t) @@ -40,24 +40,6 @@ NONZERO_TYPED_KERNEL(float) #undef NONZERO_TYPED_KERNEL_WITH_TYPE_NAME #undef NONZERO_TYPED_KERNEL -namespace { -void IncrementCoordinate(const TensorShape& shape, std::vector* coordinate) { - assert(coordinate->size() == shape.NumDimensions()); - - size_t i = 0; - const size_t i_end = coordinate->size(); - for (; i < i_end; ++i) { - const size_t i_from_back = i_end - i - 1; - if ((*coordinate)[i_from_back] != shape[i_from_back] - 1) break; - (*coordinate)[i_from_back] = 0; - } - - if (i < i_end) { - ++(*coordinate)[i_end - i - 1]; - } -} -} // namespace - template Status NonZero::Compute(OpKernelContext* context) const { const auto X = context->Input(0); @@ -71,19 +53,37 @@ Status NonZero::Compute(OpKernelContext* context) const { // reserve enough space for indices for every element of X non_zero_indices_buffer.reserve(X_shape.Size() * coordinate_size); + const T* data = X->Data(); + if (X_shape.IsScalar()) { - const T& value = *(X->Data()); + const T& value = *data; if (value != T{}) { non_zero_indices_buffer.push_back(0); } } else { std::vector coordinate(coordinate_size, 0); - for (const T& value : X->DataAsSpan()) { + + // as we iterate the entries, increment the coordinate for the current entry + // e.g. 
if shape is {2,2}, we start with 0,0 increment to 0,1 increment to 1,0 and finally 1,1 + auto increment_coordinate = [&coordinate, &coordinate_size, &X_shape]() { + for (int64_t idx = coordinate_size - 1; idx >= 0; --idx) { + int64_t& cur_coord = coordinate[idx]; + if (cur_coord != X_shape[idx] - 1) { + ++cur_coord; + break; + } + cur_coord = 0; + } + }; + + for (size_t i = 0, end = X_shape.Size(); i < end; ++i) { + const T& value = *data++; if (value != T{}) { non_zero_indices_buffer.insert(non_zero_indices_buffer.end(), coordinate.begin(), coordinate.end()); } - IncrementCoordinate(X_shape, &coordinate); + + increment_coordinate(); } } diff --git a/onnxruntime/core/providers/cpu/tensor/onehot.cc b/onnxruntime/core/providers/cpu/tensor/onehot.cc index 1dfbaaf37640f..c4f0c2479a069 100644 --- a/onnxruntime/core/providers/cpu/tensor/onehot.cc +++ b/onnxruntime/core/providers/cpu/tensor/onehot.cc @@ -18,8 +18,9 @@ limitations under the License. #include "core/util/eigen_common_wrapper.h" #include "core/platform/env.h" +#ifndef EIGEN_USE_THREADS #define EIGEN_USE_THREADS - +#endif using namespace ::onnxruntime::common; using namespace std; @@ -46,6 +47,8 @@ REG_ONE_HOT_OP(float, int64_t, int64_t); REG_ONE_HOT_OP(int64_t, string, int64_t); REG_ONE_HOT_OP(float, string, int64_t); REG_ONE_HOT_OP(int64_t, float, int64_t); +REG_ONE_HOT_OP(int32_t, float, int32_t); +REG_ONE_HOT_OP(int32_t, float, float); REG_ONE_HOT_OP(float, float, float); // added this to satisfy onnx model tests REG_ONE_HOT_OP(int64_t, int32_t, float); // added this to satisfy onnx model tests diff --git a/onnxruntime/core/providers/cpu/tensor/quantize_linear.cc b/onnxruntime/core/providers/cpu/tensor/quantize_linear.cc index 5846bc102f565..e345ad4da3cd8 100644 --- a/onnxruntime/core/providers/cpu/tensor/quantize_linear.cc +++ b/onnxruntime/core/providers/cpu/tensor/quantize_linear.cc @@ -63,21 +63,22 @@ Status DequantizeLinear::Compute(OpKernelContext* ctx) const { ONNX_CPU_OPERATOR_TYPED_KERNEL( QuantizeLinear, 10, - float, + uint8_t, KernelDefBuilder() .TypeConstraint("x", DataTypeImpl::GetTensorType()) - .TypeConstraint("y_scale", DataTypeImpl::GetTensorType()) .TypeConstraint("y_zero_point", DataTypeImpl::GetTensorType()) .TypeConstraint("y", DataTypeImpl::GetTensorType()), - QuantizeLinear); - -// clamp doesn't exist in the version of that we're using, so -// make a local one. 
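A minimal standalone sketch of the scale / zero-point math shared by the quantization kernels above and below, assuming a uint8 output and the formula y = saturate(round_half_to_even(x / scale) + zero_point) from the hunks; the helper names and the zero-range guard are illustrative additions, not part of the patch:

#include <algorithm>
#include <cfenv>
#include <cmath>
#include <cstdint>
#include <vector>

// Round-half-to-even, as in the kernels above.
static float RoundHalfToEvenSketch(float v) {
  std::fesetround(FE_TONEAREST);
  return std::nearbyintf(v);
}

struct QuantParamsSketch {
  float scale;
  uint8_t zero_point;
};

// Derive scale/zero point the way DynamicQuantizeLinear does: extend the input
// range to include 0 so zero stays exactly representable, then
// scale = (max - min) / (qmax - qmin) and zero_point = round(qmin - min / scale).
QuantParamsSketch ComputeQuantParamsSketch(const std::vector<float>& x) {
  const float qmin = 0.0f, qmax = 255.0f;  // uint8 quantization range
  float min = std::min(0.0f, *std::min_element(x.begin(), x.end()));
  float max = std::max(0.0f, *std::max_element(x.begin(), x.end()));
  float scale = (max - min) / (qmax - qmin);
  if (scale == 0.0f) scale = 1.0f;  // guard for an all-zero input; not in the hunk
  const float initial_zp = qmin - min / scale;
  const auto zp = static_cast<uint8_t>(
      RoundHalfToEvenSketch(std::max(qmin, std::min(qmax, initial_zp))));
  return {scale, zp};
}

// Quantize one value with the shared formula.
uint8_t QuantizeOneSketch(float x, const QuantParamsSketch& p) {
  const float v = RoundHalfToEvenSketch(x / p.scale) + p.zero_point;
  return static_cast<uint8_t>(std::max(0.0f, std::min(255.0f, v)));
}

For example, an input spanning [-1, 2] yields scale = 3/255 and zero_point = 85.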
-static float clamp(float v, float lo, float hi) { - if (v < lo) return lo; - if (v > hi) return hi; - return v; -} + QuantizeLinear); + +ONNX_CPU_OPERATOR_TYPED_KERNEL( + QuantizeLinear, + 10, + int8_t, + KernelDefBuilder() + .TypeConstraint("x", DataTypeImpl::GetTensorType()) + .TypeConstraint("y_zero_point", DataTypeImpl::GetTensorType()) + .TypeConstraint("y", DataTypeImpl::GetTensorType()), + QuantizeLinear); static float RoundHalfToEven(float input) { std::fesetround(FE_TONEAREST); @@ -85,9 +86,9 @@ static float RoundHalfToEven(float input) { return result; } -template <> +template // formula is Y = X / Scale + ZeroPoint -Status QuantizeLinear::Compute(OpKernelContext* ctx) const { +Status QuantizeLinear::Compute(OpKernelContext* ctx) const { auto& x = *ctx->Input(0); auto& y_scale = *ctx->Input(1); auto& y_zero_point = *ctx->Input(2); @@ -102,14 +103,18 @@ Status QuantizeLinear::Compute(OpKernelContext* ctx) const { ORT_ENFORCE(scale_shape.NumDimensions() == 0 || (scale_shape.NumDimensions() == 1 && scale_shape.GetDims().size() == 1), "x_scale must be a scalar."); ORT_ENFORCE(zero_point_shape.NumDimensions() == 0 || (zero_point_shape.NumDimensions() == 1 && zero_point_shape.GetDims().size() == 1), "x_zero_point must be a scalar."); - const uint8_t zero_point = *(y_zero_point.template Data()); + const T zero_point = *(y_zero_point.template Data()); const float scale = *(y_scale.template Data()); const auto* input = x.template Data(); - auto* output = y.template MutableData(); + auto* output = y.template MutableData(); const auto num_of_elements = x_shape.Size(); + const float qmax = std::numeric_limits::max(); + const float qmin_default = std::numeric_limits::min(); + // adjust qmin for int8 inputs. This is required to keep zero point as zero + const float qmin = qmin_default == -128 ? 
-127 : qmin_default; for (int i = 0; i < num_of_elements; ++i) { - output[i] = static_cast(clamp(RoundHalfToEven(static_cast(input[i]/scale)) + zero_point, 0.0f, float(UINT8_MAX))); + output[i] = static_cast(clamp(RoundHalfToEven(static_cast(input[i]/scale)) + zero_point, qmin, qmax)); } return Status::OK(); diff --git a/onnxruntime/core/providers/cpu/tensor/size.cc b/onnxruntime/core/providers/cpu/tensor/size.cc index 675c14b8cfee6..75bdd5bec204e 100644 --- a/onnxruntime/core/providers/cpu/tensor/size.cc +++ b/onnxruntime/core/providers/cpu/tensor/size.cc @@ -41,7 +41,8 @@ ONNX_CPU_OPERATOR_KERNEL( DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType(), - DataTypeImpl::GetTensorType()})), + DataTypeImpl::GetTensorType()})) + .TypeConstraint("T1", DataTypeImpl::GetTensorType()), Size); } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cpu/tensor/tile.cc b/onnxruntime/core/providers/cpu/tensor/tile.cc index 984f490adec9d..1b0ab391fbe41 100644 --- a/onnxruntime/core/providers/cpu/tensor/tile.cc +++ b/onnxruntime/core/providers/cpu/tensor/tile.cc @@ -34,7 +34,8 @@ ONNX_CPU_OPERATOR_KERNEL( DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType(), - DataTypeImpl::GetTensorType()}), + DataTypeImpl::GetTensorType()}) + .TypeConstraint("T1", DataTypeImpl::GetTensorType()), Tile); Status TileCoreForFixedSizeTypes(const Tensor& input_tensor, Tensor& output_tensor, const int64_t* repeats, TensorAxisCounters& input_counters, const TensorPitches& output_pitches, size_t element_size) { diff --git a/onnxruntime/core/providers/cpu/tensor/upsample.cc b/onnxruntime/core/providers/cpu/tensor/upsample.cc index 95605dbef4a68..3dcfcb47a353b 100644 --- a/onnxruntime/core/providers/cpu/tensor/upsample.cc +++ b/onnxruntime/core/providers/cpu/tensor/upsample.cc @@ -3,6 +3,7 @@ #include "core/providers/cpu/tensor/upsample.h" #include +#include using namespace onnxruntime::common; using namespace std; @@ -61,14 +62,18 @@ Status UpsampleNearest(const T* input, T* output, const TensorShape& input_shape, const TensorShape& output_shape, - const vector& scales) { + const vector& scales, + bool is_resize) { if (!input || !output) - return Status(ONNXRUNTIME, FAIL, "Upsample: input/output value is nullptr"); + return Status(ONNXRUNTIME, FAIL, is_resize ? "Resize: input/output value is nullptr" : + "Upsample: input/output value is nullptr"); if (input_shape.NumDimensions() != output_shape.NumDimensions()) - return Status(ONNXRUNTIME, FAIL, "Upsample: input/output value's dimension mismatch"); + return Status(ONNXRUNTIME, FAIL, is_resize ? "Resize: input/output value's dimension mismatch" : + "Upsample: input/output value's dimension mismatch"); if (input_shape.NumDimensions() == 0) { return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, - "Upsample: input shape needs to be at least a single dimension."); + is_resize ? "Resize: input shape needs to be at least a single dimension" : + "Upsample: input shape needs to be at least a single dimension."); } int64_t n_dim = static_cast(input_shape.NumDimensions()); @@ -192,11 +197,14 @@ Status upsampleLiner(const T* input, T* output, const TensorShape& input_shape, const TensorShape& output_shape, - const vector& scales) { + const vector& scales, + bool is_resize) { if (!input || !output) - return Status(ONNXRUNTIME, FAIL, "Upsample: input/output value is nullptr"); + return Status(ONNXRUNTIME, FAIL, is_resize ? 
"Resize: input / output value is nullptr" : + "Upsample: input / output value is nullptr"); if (input_shape.NumDimensions() != output_shape.NumDimensions()) - return Status(ONNXRUNTIME, FAIL, "Upsample: input/output value's dimension mismatch"); + return Status(ONNXRUNTIME, FAIL, is_resize ? "Resize: input/output value's dimension mismatch" : + "Upsample: input/output value's dimension mismatch"); auto n_dim = input_shape.NumDimensions(); for (size_t i = 0, size = output_shape.Size(); i < size; i++) { std::vector val1; @@ -242,6 +250,11 @@ Status upsampleLiner(const T* input, return Status::OK(); } +// The following method supports a 4-D input in 'Linear mode' +// that amounts to 'Bilinear' Upsampling/Resizing in the sense that it assumes +// the scale values for the outermost 2 dimensions are 1. +// This is the common use-case where the 4-D input (batched multi-channel images) +// is usually of shape [N, C, H, W] and the scales are [1.0, 1.0, height_scale, width_scale] template void upsampleBilinear( int64_t batch_size, @@ -327,9 +340,10 @@ Status Upsample::BaseCompute(OpKernelContext* context, const std::vector& dims = X->Shape().GetDims(); - if (dims.size() != scales.size()) { - return Status(ONNXRUNTIME, INVALID_ARGUMENT, "Upsample: input tensor's dimension does not match the scales."); - } + if (dims.size() != scales.size()) + return Status(ONNXRUNTIME, INVALID_ARGUMENT, + is_resize ? "Resize: input tensor's dimension does not match the scales." : + "Upsample: input tensor's dimension does not match the scales."); bool no_scale = true; std::vector Y_dims; @@ -348,26 +362,33 @@ Status Upsample::BaseCompute(OpKernelContext* context, const std::vector(X->template Data(), Y->template MutableData(), X->Shape(), Y->Shape(), scales); + return UpsampleNearest(X->template Data(), Y->template MutableData(), X->Shape(), Y->Shape(), scales, is_resize); case UpsampleMode::LINEAR: { - //What's the correct behavior of linear mode is not clear right now, - //Only support bilinear with 4D tensor to keep consistent with previous behavior - if (dims.size() != 4) - return Status(ONNXRUNTIME, FAIL, "Upsample: linear mode upsample only support 4-D tensor with NCHW layout"); + //The correct behavior of 'linear' mode for an N-D input is not clear right now, + //so only support 'bilinear' with 2-D or 4-D input tensor with outermost 2 scales as 1 in the 4-D case + if (dims.size() != 2 && dims.size() != 4) { + std::ostringstream oss; + oss << "'Linear' mode only support 2-D inputs ('Bilinear') or 4-D inputs " + "with the corresponding outermost 2 scale values being 1 in the "; + oss << (is_resize ? "Resize operator" : "Upsample operator"); + return Status(ONNXRUNTIME, FAIL, oss.str()); + } - const int64_t batch_size = dims[0]; - const int64_t num_channels = dims[1]; - const int64_t input_height = dims[2]; - const int64_t input_width = dims[3]; + bool is_2D = dims.size() == 2; + const int64_t batch_size = is_2D ? 1 : dims[0]; + const int64_t num_channels = is_2D ? 1 : dims[1]; + const int64_t input_height = is_2D ? dims[0] : dims[2]; + const int64_t input_width = is_2D ? dims[1] : dims[3]; AllocatorPtr alloc; ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&alloc)); upsampleBilinear(batch_size, num_channels, input_height, input_width, - scales[2], scales[3], X->template Data(), Y->template MutableData(), alloc); + is_2D ? scales[0] : scales[2], is_2D ? 
scales[1] : scales[3], + X->template Data(), Y->template MutableData(), alloc); return Status::OK(); } default: - return Status(ONNXRUNTIME, FAIL, "Upsample: unexpected mode"); + return Status(ONNXRUNTIME, FAIL, is_resize ? "Resize: unexpected mode" : "Upsample: unexpected mode"); } } @@ -380,9 +401,9 @@ Status Upsample::Compute(OpKernelContext* context) const { const auto* scales = context->Input(1); ORT_ENFORCE(scales != nullptr); int64_t scales_size = scales->Shape().Size(); - std::vector scales_arrary(scales_size); - ParseScalesData(scales, scales_arrary); - return BaseCompute(context, scales_arrary); + std::vector scales_array(scales_size); + ParseScalesData(scales, scales_array); + return BaseCompute(context, scales_array); } } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cpu/tensor/upsample.h b/onnxruntime/core/providers/cpu/tensor/upsample.h index 5c57295af5195..97b41e0915d89 100644 --- a/onnxruntime/core/providers/cpu/tensor/upsample.h +++ b/onnxruntime/core/providers/cpu/tensor/upsample.h @@ -72,9 +72,10 @@ class UpsampleBase { } if (UpsampleMode::LINEAR == mode) { - ORT_ENFORCE(scales.size() == 4, "Upsample: linear mode upsample only support bilinear with 4 dimension."); - ORT_ENFORCE(((scales[0] == 1) && (scales[1] == 1)), - "Upsample: linear mode upsample only support bilinear, the first 2 scales should be 1."); + ORT_ENFORCE(scales.size() == 2 || (scales.size() == 4 && scales[0] == 1 && scales[1] == 1), + "'Linear' mode only support 2-D inputs ('Bilinear') or 4-D inputs " + "with the corresponding outermost 2 scale values being 1 in the ", + is_resize ? "Resize operator" : "Upsample operator"); } } diff --git a/onnxruntime/core/providers/cpu/tensor/where_op.cc b/onnxruntime/core/providers/cpu/tensor/where_op.cc index 21e0243dcf46a..bd946c4619f1e 100644 --- a/onnxruntime/core/providers/cpu/tensor/where_op.cc +++ b/onnxruntime/core/providers/cpu/tensor/where_op.cc @@ -29,7 +29,7 @@ namespace onnxruntime { //WHERE_TYPED_KERNEL(int8_t) //WHERE_TYPED_KERNEL(int16_t) WHERE_TYPED_KERNEL(int32_t) -//WHERE_TYPED_KERNEL(int64_t) +WHERE_TYPED_KERNEL(int64_t) //WHERE_TYPED_KERNEL(MLFloat16) //WHERE_TYPED_KERNEL(BFloat16) WHERE_TYPED_KERNEL(float) diff --git a/onnxruntime/core/providers/cuda/cuda_allocator.cc b/onnxruntime/core/providers/cuda/cuda_allocator.cc index 44cbbd75d0fc2..5241545763c38 100644 --- a/onnxruntime/core/providers/cuda/cuda_allocator.cc +++ b/onnxruntime/core/providers/cuda/cuda_allocator.cc @@ -61,8 +61,7 @@ void CUDAPinnedAllocator::Free(void* p) { } const OrtAllocatorInfo& CUDAPinnedAllocator::Info() const { - static constexpr OrtAllocatorInfo cuda_allocator_info(CUDA_PINNED, OrtDeviceAllocator, OrtDevice(OrtDevice::CPU, OrtDevice::MemType::CUDA_PINNED, 0), 0, OrtMemTypeCPUOutput); - return cuda_allocator_info; + return info_; } FencePtr CUDAPinnedAllocator::CreateFence(const SessionState* session_state) { diff --git a/onnxruntime/core/providers/cuda/cuda_allocator.h b/onnxruntime/core/providers/cuda/cuda_allocator.h index 06f6caa784c0e..2840dcb4088c3 100644 --- a/onnxruntime/core/providers/cuda/cuda_allocator.h +++ b/onnxruntime/core/providers/cuda/cuda_allocator.h @@ -9,7 +9,7 @@ namespace onnxruntime { class CUDAAllocator : public IDeviceAllocator { public: - CUDAAllocator(int device_id) : info_(CUDA, OrtAllocatorType::OrtDeviceAllocator, OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, device_id), device_id, OrtMemTypeDefault) {} + CUDAAllocator(int device_id, const char* name) : info_(name, OrtAllocatorType::OrtDeviceAllocator, 
OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, device_id), device_id, OrtMemTypeDefault) {} virtual void* Alloc(size_t size) override; virtual void Free(void* p) override; virtual const OrtAllocatorInfo& Info() const override; @@ -25,10 +25,14 @@ class CUDAAllocator : public IDeviceAllocator { //TODO: add a default constructor class CUDAPinnedAllocator : public IDeviceAllocator { public: + CUDAPinnedAllocator(int device_id, const char* name) : info_(name, OrtAllocatorType::OrtDeviceAllocator, OrtDevice(OrtDevice::CPU, OrtDevice::MemType::CUDA_PINNED, device_id), device_id, OrtMemTypeCPUOutput) {} virtual void* Alloc(size_t size) override; virtual void Free(void* p) override; virtual const OrtAllocatorInfo& Info() const override; virtual FencePtr CreateFence(const SessionState* session_state) override; + + private: + const OrtAllocatorInfo info_; }; } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc index 6509cf01fdf9a..04a87a120bf0f 100644 --- a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc +++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc @@ -52,7 +52,7 @@ CUDAExecutionProvider::PerThreadContext::PerThreadContext(int device_id) { DeviceAllocatorRegistrationInfo default_allocator_info( {OrtMemTypeDefault, - [](int id) { return std::make_unique(id); }, std::numeric_limits::max()}); + [](int id) { return std::make_unique(id, CUDA); }, std::numeric_limits::max()}); allocator_ = CreateAllocator(default_allocator_info, device_id); } @@ -66,12 +66,17 @@ CUDAExecutionProvider::CUDAExecutionProvider(const CUDAExecutionProviderInfo& in CUDA_CALL_THROW(cudaSetDevice(device_id_)); DeviceAllocatorRegistrationInfo default_allocator_info( - {OrtMemTypeDefault, [](int id) { return std::make_unique(id); }, std::numeric_limits::max()}); + {OrtMemTypeDefault, [](int id) { return std::make_unique(id, CUDA); }, std::numeric_limits::max()}); InsertAllocator(CreateAllocator(default_allocator_info, device_id_)); DeviceAllocatorRegistrationInfo pinned_allocator_info( - {OrtMemTypeCPUOutput, [](int) { return std::make_unique(); }, std::numeric_limits::max()}); + {OrtMemTypeCPUOutput, [](int) { return std::make_unique(0, CUDA_PINNED); }, std::numeric_limits::max()}); InsertAllocator(CreateAllocator(pinned_allocator_info, device_id_)); + + // TODO: this is actually used for the cuda kernels which explicitly ask for inputs from CPU. + // This will be refactored/removed when allocator and execution provider are decoupled. 
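The pinned-allocator change above replaces a function-local static OrtAllocatorInfo with a per-instance info_ member, so allocators constructed with different names and devices report distinct metadata. A minimal sketch of that pattern, using illustrative stand-in types rather than the real ORT classes:

#include <string>

// Stand-in for OrtAllocatorInfo; only the fields needed for the illustration.
struct AllocatorInfoSketch {
  std::string name;
  int device_id;
};

class PinnedAllocatorSketch {
 public:
  PinnedAllocatorSketch(int device_id, const char* name) : info_{name, device_id} {}
  // Per-instance metadata instead of a single shared static, so two allocators
  // created with different names/devices no longer report the same info.
  const AllocatorInfoSketch& Info() const { return info_; }

 private:
  const AllocatorInfoSketch info_;
};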
+ DeviceAllocatorRegistrationInfo cpu_allocator_info({OrtMemTypeCPUInput, [](int) { return std::make_unique(std::make_unique("CUDA_CPU", OrtAllocatorType::OrtDeviceAllocator, OrtDevice(), 0, OrtMemTypeCPUInput)); }, std::numeric_limits::max()}); + InsertAllocator(CreateAllocator(cpu_allocator_info)); } CUDAExecutionProvider::~CUDAExecutionProvider() { @@ -1013,6 +1018,7 @@ CUDAExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph, // Note that nodes with only inputs from initializer would not be place on CUDA // Ideally, those nodes should be eliminated in constant folding bool should_force_outside = true; + bool all_input_are_initializer = true; node.ForEachWithIndex( node.InputDefs(), [&](const NodeArg& def, size_t index) { @@ -1020,12 +1026,17 @@ CUDAExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph, // The input is not a initializer and the input is from CPU // or the input declared as CPU memory and is from CPU // in that case we should still keep the node on CUDA - if ((!graph.GetInitializedTensor(def.Name(), initializer) && !defs_outside_cuda.count(&def)) || + bool initializer_input = graph.GetInitializedTensor(def.Name(), initializer); + if ((!initializer_input && !defs_outside_cuda.count(&def)) || (defs_outside_cuda.count(&def) && cuda_kernel_def->kernel_def->IsInputOnCpu(index))) should_force_outside = false; + if (!initializer_input) { + all_input_are_initializer = false; + } return Status::OK(); }); - if (should_force_outside) { + // If all the inputs are initialier, we shouldn't force it to CPU + if (should_force_outside && !all_input_are_initializer) { force_outside = true; } } diff --git a/onnxruntime/core/providers/cuda/cudnn_common.h b/onnxruntime/core/providers/cuda/cudnn_common.h index 02a3ba6b694bb..bfd233b68b65e 100644 --- a/onnxruntime/core/providers/cuda/cudnn_common.h +++ b/onnxruntime/core/providers/cuda/cudnn_common.h @@ -97,13 +97,13 @@ class CudnnDropout final { return dropout_desc_; } - private: Status CreateDescriptorIfNeeded() { if (!dropout_desc_) CUDNN_RETURN_IF_ERROR(cudnnCreateDropoutDescriptor(&dropout_desc_)); return Status::OK(); } + private: cudnnDropoutDescriptor_t dropout_desc_; }; diff --git a/onnxruntime/core/providers/cuda/math/binary_elementwise_ops.cc b/onnxruntime/core/providers/cuda/math/binary_elementwise_ops.cc index 16f6246b3df37..6f679c8a6cbf7 100644 --- a/onnxruntime/core/providers/cuda/math/binary_elementwise_ops.cc +++ b/onnxruntime/core/providers/cuda/math/binary_elementwise_ops.cc @@ -92,6 +92,17 @@ Status BinaryElementwise::Prepare(OpKernelContext* context, int KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), \ x); +#define BINARY_ELEMENTWISE_LOGICALOP_REGISTER_KERNEL_TYPED(x, ver, T) \ + ONNX_OPERATOR_TYPED_KERNEL_EX( \ + x, \ + kOnnxDomain, \ + ver, \ + T, \ + kCudaExecutionProvider, \ + KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()) \ + .TypeConstraint("T1", DataTypeImpl::GetTensorType()), \ + x); + #define BINARY_ELEMENTWISE_REGISTER_KERNEL_VERSIONED_TYPED(x, startver, endver, T) \ ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \ x, \ @@ -127,6 +138,11 @@ Status BinaryElementwise::Prepare(OpKernelContext* context, int BINARY_ELEMENTWISE_REGISTER_KERNEL_TYPED(name, ver, T) \ BINARY_ELEMENTWISE_COMPUTE(name, T) +#define BINARY_LOGICALOP_TYPED(name, ver, T) \ + BINARY_ELEMENTWISE_LOGICALOP_REGISTER_KERNEL_TYPED(name, ver, T) \ + BINARY_ELEMENTWISE_COMPUTE(name, T) + + // since different ops has different types, we cannot use BINARY_OPS() directly // the 
postfix of means the types supported by the op: // B: uint8_t @@ -155,10 +171,15 @@ Status BinaryElementwise::Prepare(OpKernelContext* context, int BINARY_OP_HFD(name, ver) #define BINARY_OP_REGISTER_OIL(name, ver) \ - BINARY_ELEMENTWISE_REGISTER_KERNEL_TYPED(name, ver, bool) \ + BINARY_ELEMENTWISE_REGISTER_KERNEL_TYPED(name, ver, bool) \ BINARY_ELEMENTWISE_REGISTER_KERNEL_TYPED(name, ver, int32_t) \ BINARY_ELEMENTWISE_REGISTER_KERNEL_TYPED(name, ver, int64_t) +#define BINARY_LOGICALOP_REGISTER_OIL(name, ver) \ + BINARY_ELEMENTWISE_LOGICALOP_REGISTER_KERNEL_TYPED(name, ver, bool) \ + BINARY_ELEMENTWISE_LOGICALOP_REGISTER_KERNEL_TYPED(name, ver, int32_t) \ + BINARY_ELEMENTWISE_LOGICALOP_REGISTER_KERNEL_TYPED(name, ver, int64_t) + #define BINARY_OP_REGISTER_HFD(name, ver) \ BINARY_ELEMENTWISE_REGISTER_KERNEL_TYPED(name, ver, MLFloat16) \ BINARY_ELEMENTWISE_REGISTER_KERNEL_TYPED(name, ver, float) \ @@ -171,6 +192,15 @@ Status BinaryElementwise::Prepare(OpKernelContext* context, int BINARY_ELEMENTWISE_REGISTER_KERNEL_TYPED(name, ver, int64_t) \ BINARY_OP_REGISTER_HFD(name, ver) +#define BINARY_LOGICALOP_REGISTER_UZILHFD(name, ver) \ + BINARY_ELEMENTWISE_LOGICALOP_REGISTER_KERNEL_TYPED(name, ver, uint32_t) \ + BINARY_ELEMENTWISE_LOGICALOP_REGISTER_KERNEL_TYPED(name, ver, uint64_t) \ + BINARY_ELEMENTWISE_LOGICALOP_REGISTER_KERNEL_TYPED(name, ver, int32_t) \ + BINARY_ELEMENTWISE_LOGICALOP_REGISTER_KERNEL_TYPED(name, ver, int64_t) \ + BINARY_ELEMENTWISE_LOGICALOP_REGISTER_KERNEL_TYPED(name, ver, MLFloat16) \ + BINARY_ELEMENTWISE_LOGICALOP_REGISTER_KERNEL_TYPED(name, ver, float) \ + BINARY_ELEMENTWISE_LOGICALOP_REGISTER_KERNEL_TYPED(name, ver, double) + #define BINARY_OP_REGISTER_VERSIONED_HFD(name, startver, endver) \ BINARY_ELEMENTWISE_REGISTER_KERNEL_VERSIONED_TYPED(name, startver, endver, MLFloat16) \ BINARY_ELEMENTWISE_REGISTER_KERNEL_VERSIONED_TYPED(name, startver, endver, float) \ @@ -188,9 +218,9 @@ BINARY_OP_UZILHFD(Sub, 7) BINARY_OP_UZILHFD(Mul, 7) BINARY_OP_UZILHFD(Div, 7) BINARY_OP_HFD(Pow, 7) -BINARY_OP_TYPED(And, 7, bool) -BINARY_OP_TYPED(Or, 7, bool) -BINARY_OP_TYPED(Xor, 7, bool) +BINARY_LOGICALOP_TYPED(And, 7, bool) +BINARY_LOGICALOP_TYPED(Or, 7, bool) +BINARY_LOGICALOP_TYPED(Xor, 7, bool) BINARY_OP_HFD(PRelu, 7) template @@ -440,7 +470,7 @@ Status Equal::ComputeInternal(OpKernelContext* context) const { BINARY_OP_REGISTER_UZILHFD(Sum, 8) BINARY_OP_REGISTER_VERSIONED_UZILHFD(Sum, 6, 7) -BINARY_OP_REGISTER_UZILHFD(Greater, 9) +BINARY_LOGICALOP_REGISTER_UZILHFD(Greater, 9) BINARY_OP_REGISTER_OIL(Equal, 7) BINARY_OP_REGISTER_VERSIONED_HFD(Greater, 7, 8) BINARY_OP_REGISTER_HFD(Max, 8) diff --git a/onnxruntime/core/providers/cuda/rnn/cudnn_rnn_base.cc b/onnxruntime/core/providers/cuda/rnn/cudnn_rnn_base.cc index e45eb16dc5508..2b13aa5882b57 100644 --- a/onnxruntime/core/providers/cuda/rnn/cudnn_rnn_base.cc +++ b/onnxruntime/core/providers/cuda/rnn/cudnn_rnn_base.cc @@ -15,7 +15,7 @@ void CudnnRnnBase::SetWeightBias(const cudnnHandle_t handle, const cudnnTensorDescriptor_t x_desc, const cudnnFilterDescriptor_t w_desc, const cudnnFilterDescriptor_t filter_desc, - const void* w_data, + const void* reorganized_w_data, const int lin_layer_id, const T* pos, int& offset, @@ -27,9 +27,9 @@ void CudnnRnnBase::SetWeightBias(const cudnnHandle_t handle, T* mem_offset; if (is_matrix) { - cudnnGetRNNLinLayerMatrixParams(handle, rnn_desc, pseudo_layer, x_desc, w_desc, w_data, lin_layer_id, filter_desc, (void**)&mem_offset); + cudnnGetRNNLinLayerMatrixParams(handle, rnn_desc, pseudo_layer, x_desc, 
w_desc, reorganized_w_data, lin_layer_id, filter_desc, (void**)&mem_offset); } else { - cudnnGetRNNLinLayerBiasParams(handle, rnn_desc, pseudo_layer, x_desc, w_desc, w_data, lin_layer_id, filter_desc, (void**)&mem_offset); + cudnnGetRNNLinLayerBiasParams(handle, rnn_desc, pseudo_layer, x_desc, w_desc, reorganized_w_data, lin_layer_id, filter_desc, (void**)&mem_offset); } cudnnGetFilterNdDescriptor(filter_desc, 3, &dt, &tf, &numDims, matDims.data()); @@ -42,25 +42,25 @@ Status CudnnRnnBase::SetCudnnRnnWeightBias(const cudnnHandle_t cudnn_handle, const cudnnRNNDescriptor_t rnn_desc, const cudnnTensorDescriptor_t x_desc, const cudnnFilterDescriptor_t w_desc, - void* w_data, + void* reorganized_w_data, const T* W_data, const T* R_data, const T* B_data) const { - //Onnx only support 1 layer int w_offset = 0; int r_offset = 0; int bias_offset = 0; - for (int layer = 0; layer < num_layers_ * num_directions_; ++layer) { + CudnnFilterDescriptor filter_desc; + for (int layer = 0; layer < RNN_NUM_LAYERS * num_directions_; ++layer) { for (size_t idx = 0; idx < W_lin_layer_id_.size(); ++idx) { - SetWeightBias(cudnn_handle, rnn_desc, layer, x_desc, w_desc, filter_desc_, w_data, W_lin_layer_id_[idx], W_data, w_offset, true); + SetWeightBias(cudnn_handle, rnn_desc, layer, x_desc, w_desc, filter_desc, reorganized_w_data, W_lin_layer_id_[idx], W_data, w_offset, true); if (B_data != nullptr) { - SetWeightBias(cudnn_handle, rnn_desc, layer, x_desc, w_desc, filter_desc_, w_data, W_lin_layer_id_[idx], B_data, bias_offset, false); + SetWeightBias(cudnn_handle, rnn_desc, layer, x_desc, w_desc, filter_desc, reorganized_w_data, W_lin_layer_id_[idx], B_data, bias_offset, false); } } for (size_t idx = 0; idx < R_lin_layer_id_.size(); ++idx) { - SetWeightBias(cudnn_handle, rnn_desc, layer, x_desc, w_desc, filter_desc_, w_data, R_lin_layer_id_[idx], R_data, r_offset, true); + SetWeightBias(cudnn_handle, rnn_desc, layer, x_desc, w_desc, filter_desc, reorganized_w_data, R_lin_layer_id_[idx], R_data, r_offset, true); if (B_data != nullptr) { - SetWeightBias(cudnn_handle, rnn_desc, layer, x_desc, w_desc, filter_desc_, w_data, R_lin_layer_id_[idx], B_data, bias_offset, false); + SetWeightBias(cudnn_handle, rnn_desc, layer, x_desc, w_desc, filter_desc, reorganized_w_data, R_lin_layer_id_[idx], B_data, bias_offset, false); } } } @@ -68,34 +68,11 @@ Status CudnnRnnBase::SetCudnnRnnWeightBias(const cudnnHandle_t cudnn_handle, return Status::OK(); } -template -Status CudnnRnnBase::SetCudnnRnnDesc() { - typedef typename ToCudaType::MappedType CudaT; - - cudnnDirectionMode_t cudnn_direction = CUDNN_UNIDIRECTIONAL; - if (direction_ == "bidirectional") { - cudnn_direction = CUDNN_BIDIRECTIONAL; - } else if (direction_ == "forward") { - cudnn_direction = CUDNN_UNIDIRECTIONAL; - } else if (direction_ == "reverse") { - cudnn_direction = CUDNN_UNIDIRECTIONAL; - // need to reverse data - reverse_ = true; - } - - cudnn_dropout_desc_.GetCudnnDropoutStatesSize(CudnnHandle(), state_size_); - state_buffer_ = GetScratchBuffer(state_size_); - cudnn_dropout_desc_.Set(CudnnHandle(), state_buffer_.get(), state_size_); - ORT_RETURN_IF_ERROR(rnn_desc_.Set(CudnnHandle(), hidden_size_, num_layers_, cudnn_dropout_desc_, - cudnn_direction, rnn_mode_, CudnnTensor::GetDataType())); - - return Status::OK(); -} - template Status CudnnRnnBase::ReorganizeWeights(const Tensor* W, const Tensor* R, const Tensor* B, - IAllocatorUniquePtr& target_w_data, - CudnnFilterDescriptor& target_w_desc) const { + IAllocatorUniquePtr& reorganized_w_data, + 
CudnnFilterDescriptor& target_w_desc, + CudnnRNN& rnn_desc) const { typedef typename ToCudaType::MappedType CudaT; int64_t input_size = W->Shape()[2]; // RNN W[num_directions_, hidden_size_, input_size] @@ -117,20 +94,21 @@ Status CudnnRnnBase::ReorganizeWeights(const Tensor* W, const Tensor* R, cons fake_x_desc.Set(fake_dims_x, CudnnTensor::GetDataType()); // Prepare the weight data - target_w_data = GetScratchBuffer(w_size * sizeof(T)); + reorganized_w_data = GetScratchBuffer(w_size * sizeof(T)); const T* W_data = W->template Data(); const T* R_data = R->template Data(); const T* B_data = B == nullptr ? nullptr : B->template Data(); - ORT_RETURN_IF_ERROR(SetCudnnRnnWeightBias(CudnnHandle(), rnn_desc_, fake_x_desc, target_w_desc, - target_w_data.get(), W_data, R_data, B_data)); + ORT_RETURN_IF_ERROR(SetCudnnRnnWeightBias(CudnnHandle(), rnn_desc, fake_x_desc, target_w_desc, + reorganized_w_data.get(), W_data, R_data, B_data)); return Status::OK(); } template Status CudnnRnnBase::CacheCudnnRnnWeights(const OpKernelInfo& info) { + typedef typename ToCudaType::MappedType CudaT; // Cache the weight const Tensor* W; const Tensor* R; @@ -140,10 +118,13 @@ Status CudnnRnnBase::CacheCudnnRnnWeights(const OpKernelInfo& info) { bool get_B = info.TryGetConstantInput(RNN_Input_Index::B, &B); if (get_W && get_R) { + CudnnRNN tmp_rnn_desc; + ORT_RETURN_IF_ERROR(tmp_rnn_desc.Set(CudnnHandle(), hidden_size_, RNN_NUM_LAYERS, cudnn_dropout_desc_, + cudnn_direction_mode_, rnn_mode_, CudnnTensor::GetDataType())); if (get_B) { - ORT_RETURN_IF_ERROR(ReorganizeWeights(W, R, B, w_data_cache_, w_desc_cache_)); + ORT_RETURN_IF_ERROR(ReorganizeWeights(W, R, B, w_data_cache_, w_desc_cache_, tmp_rnn_desc)); } else { - ORT_RETURN_IF_ERROR(ReorganizeWeights(W, R, nullptr, w_data_cache_, w_desc_cache_)); + ORT_RETURN_IF_ERROR(ReorganizeWeights(W, R, nullptr, w_data_cache_, w_desc_cache_, tmp_rnn_desc)); } weight_cached_ = true; } @@ -173,7 +154,7 @@ Status CudnnRnnBase::ComputeInternal(OpKernelContext* ctx) const { // optional outputs std::vector dims_Y({seq_length, num_directions_, batch_size, hidden_size_}); - std::vector dims_hxy({num_layers_ * num_directions_, batch_size, hidden_size_}); + std::vector dims_hxy({RNN_NUM_LAYERS * num_directions_, batch_size, hidden_size_}); std::vector dims_yc{num_directions_, batch_size, hidden_size_}; Tensor* Y = ctx->Output(Output_Index::Y, dims_Y); Tensor* Y_h = ctx->Output(Output_Index::Y_h, dims_hxy); @@ -198,16 +179,6 @@ Status CudnnRnnBase::ComputeInternal(OpKernelContext* ctx) const { ORT_RETURN_IF_ERROR(y_h_desc.Set(dims_hxy, CudnnTensor::GetDataType())); ORT_RETURN_IF_ERROR(y_c_desc.Set(dims_hxy, CudnnTensor::GetDataType())); - // Prepare the weight data - IAllocatorUniquePtr w_data; - CudnnFilterDescriptor w_desc; - if (!weight_cached_) { - const Tensor& W = *ctx->Input(RNN_Input_Index::W); - const Tensor& R = *ctx->Input(RNN_Input_Index::R); - const Tensor* B = ctx->Input(RNN_Input_Index::B); - ReorganizeWeights(&W, &R, B, w_data, w_desc); - } - IAllocatorUniquePtr x_reversed_data; const T* x_data = X->template Data(); if (reverse_) { @@ -239,16 +210,33 @@ Status CudnnRnnBase::ComputeInternal(OpKernelContext* ctx) const { const int32_t* sequence_lens_data = (sequence_lens == nullptr) ? 
nullptr : sequence_lens->template Data(); + CudnnRNN rnn_desc; + ORT_RETURN_IF_ERROR(rnn_desc.Set(CudnnHandle(), hidden_size_, RNN_NUM_LAYERS, cudnn_dropout_desc_, + cudnn_direction_mode_, rnn_mode_, CudnnTensor::GetDataType())); + + // Prepare the weight data + IAllocatorUniquePtr w_data; + CudnnFilterDescriptor w_desc; + if (!weight_cached_) { + const Tensor& W = *ctx->Input(RNN_Input_Index::W); + const Tensor& R = *ctx->Input(RNN_Input_Index::R); + const Tensor* B = ctx->Input(RNN_Input_Index::B); + ReorganizeWeights(&W, &R, B, w_data, w_desc, rnn_desc); + } + // CUDNN_RNN_DATA_LAYOUT_SEQ_MAJOR_UNPACKED works with CUDNN_RNN_PADDED_IO_ENABLED, so that it will auto fill 0 for the shorter sequences - CUDNN_RETURN_IF_ERROR(cudnnSetRNNPaddingMode(rnn_desc_, CUDNN_RNN_PADDED_IO_ENABLED)); + CUDNN_RETURN_IF_ERROR(cudnnSetRNNPaddingMode(rnn_desc, CUDNN_RNN_PADDED_IO_ENABLED)); size_t workspace_bytes; - CUDNN_RETURN_IF_ERROR(cudnnGetRNNWorkspaceSize(CudnnHandle(), rnn_desc_, gsl::narrow_cast(seq_length), x_desc.data(), &workspace_bytes)); + CUDNN_RETURN_IF_ERROR(cudnnGetRNNWorkspaceSize(CudnnHandle(), rnn_desc, gsl::narrow_cast(seq_length), x_desc.data(), &workspace_bytes)); auto workspace_cuda = GetScratchBuffer(workspace_bytes); + int32_t zero_seq_count = 0; + std::vector zero_seq_index_cache(batch_size, 0); + int64_t zero_seq_index_cache_size = 0; if (CUDNN_RNN_RELU == rnn_mode_ || CUDNN_RNN_TANH == rnn_mode_ || nullptr == sequence_lens_data) { CUDNN_RETURN_IF_ERROR(cudnnRNNForwardInference(CudnnHandle(), - rnn_desc_, + rnn_desc, gsl::narrow_cast(seq_length), x_desc.data(), x_data_input, @@ -267,13 +255,35 @@ Status CudnnRnnBase::ComputeInternal(OpKernelContext* ctx) const { workspace_cuda.get(), workspace_bytes)); } else { + // cudnn doesn't support 0 sequence inside the batch, find the 0 sequence and set it to 1 + // there's a ZeroMask kernel to reset the result to 0 for the 0 sequence + std::vector seq_len_array(sequence_lens_data, sequence_lens_data + batch_size); + for (int i = 0; i < batch_size; ++i) { + if (0 == seq_len_array[i]) { + seq_len_array[i] = 1; + zero_seq_index_cache[zero_seq_count] = i; + ++zero_seq_count; + } + } + + // Calculate the zero position cache for reverse direction if it's bidirectional + // The cache is for Y_h or Y_c, and the 1st sequence for Y, no need to do it for other sequence in Y since + // we hacked the 0 sequence to 1 + if (zero_seq_count && num_directions_ > 1) { + zero_seq_index_cache_size = zero_seq_count * num_directions_; + zero_seq_index_cache.resize(zero_seq_index_cache_size); + for (int i = 0; i < zero_seq_count; ++i) { + zero_seq_index_cache[zero_seq_count + i] = static_cast(batch_size + zero_seq_index_cache[i]); + } + } + CudnnDataTensor x_desc; - x_desc.Set(CudnnTensor::GetDataType(), seq_length, batch_size, input_size, sequence_lens_data); + x_desc.Set(CudnnTensor::GetDataType(), seq_length, batch_size, input_size, seq_len_array.data()); CudnnDataTensor y_desc; - y_desc.Set(CudnnTensor::GetDataType(), seq_length, batch_size, hidden_size_ * num_directions_, sequence_lens_data); + y_desc.Set(CudnnTensor::GetDataType(), seq_length, batch_size, hidden_size_ * num_directions_, seq_len_array.data()); CUDNN_RETURN_IF_ERROR(cudnnRNNForwardInferenceEx(CudnnHandle(), - rnn_desc_, + rnn_desc, x_desc, x_data_input, hx_desc, @@ -292,8 +302,13 @@ Status CudnnRnnBase::ComputeInternal(OpKernelContext* ctx) const { nullptr, nullptr, nullptr, nullptr, workspace_cuda.get(), workspace_bytes)); + // Early terminate for this case since Y data is not required, and 
Y_h is obtained correctly, no need the following code to retrive Y_h from Y data. if (nullptr == Y) { + // Mask on output for 0 sequence batches + if (zero_seq_count > 0) { + SetZeroSequences(zero_seq_index_cache_size, zero_seq_index_cache, y_data, y_h_data, y_c_data); + } return Status::OK(); } } @@ -327,10 +342,14 @@ Status CudnnRnnBase::ComputeInternal(OpKernelContext* ctx) const { } } + // Mask on output for 0 sequence batches + if (zero_seq_count > 0) { + SetZeroSequences(zero_seq_index_cache_size, zero_seq_index_cache, y_data, y_h_data, y_c_data); + } + if ((CUDNN_RNN_RELU == rnn_mode_ || CUDNN_RNN_TANH == rnn_mode_) && sequence_lens_data != nullptr && y_h_data != nullptr && y_data != nullptr) { - auto count = sequence_lens->Shape().Size(); - CudaAsyncBuffer sequence_lens_buffer(this, GetDeviceId(), count); - memcpy(sequence_lens_buffer.CpuPtr(), sequence_lens_data, count * sizeof(int32_t)); + CudaAsyncBuffer sequence_lens_buffer(this, GetDeviceId(), batch_size); + memcpy(sequence_lens_buffer.CpuPtr(), sequence_lens_data, batch_size * sizeof(int32_t)); sequence_lens_buffer.CopyToGpu(); RnnMaskImpl(gsl::narrow_cast(num_directions_), gsl::narrow_cast(seq_length), @@ -345,6 +364,24 @@ Status CudnnRnnBase::ComputeInternal(OpKernelContext* ctx) const { return Status::OK(); } +template +void CudnnRnnBase::SetZeroSequences(const int64_t zero_seq_index_cache_size, + const std::vector zero_seq_index_cache, + T* y_data, + T* y_h_data, + T* y_c_data) const { + typedef typename ToCudaType::MappedType CudaT; + CudaAsyncBuffer zero_seq_index_cache_async_buffer(this, GetDeviceId(), zero_seq_index_cache_size); + memcpy(zero_seq_index_cache_async_buffer.CpuPtr(), zero_seq_index_cache.data(), zero_seq_index_cache_size * sizeof(int32_t)); + zero_seq_index_cache_async_buffer.CopyToGpu(); + MaskZeroSequences(gsl::narrow_cast(hidden_size_), + reinterpret_cast(y_data), + reinterpret_cast(y_h_data), + reinterpret_cast(y_c_data), + zero_seq_index_cache_async_buffer.GpuPtr(), + static_cast(zero_seq_index_cache_size)); +} + template class CudnnRnnBase; template class CudnnRnnBase; template class CudnnRnnBase; diff --git a/onnxruntime/core/providers/cuda/rnn/cudnn_rnn_base.h b/onnxruntime/core/providers/cuda/rnn/cudnn_rnn_base.h index 0afd35435cc7c..6b7f4e9c14f5e 100644 --- a/onnxruntime/core/providers/cuda/rnn/cudnn_rnn_base.h +++ b/onnxruntime/core/providers/cuda/rnn/cudnn_rnn_base.h @@ -21,26 +21,29 @@ enum RNN_Input_Index { initial_c = 6 }; +// Onnx RNN/GRU/LSTM only support 1 layer +const int RNN_NUM_LAYERS = 1; + class CudnnRNN { public: - CudnnRNN() : rnn_desc_(nullptr) { + CudnnRNN() : cudnn_rnn_desc_(nullptr) { } ~CudnnRNN() { - if (rnn_desc_ != nullptr) { - cudnnDestroyRNNDescriptor(rnn_desc_); - rnn_desc_ = nullptr; + if (cudnn_rnn_desc_ != nullptr) { + cudnnDestroyRNNDescriptor(cudnn_rnn_desc_); + cudnn_rnn_desc_ = nullptr; } } Status Set(const cudnnHandle_t& cudnnHandle, int64_t hidden_size, int num_layers, cudnnDropoutDescriptor_t cudnn_dropout_desc, cudnnDirectionMode_t cudnn_direction_model, cudnnRNNMode_t rnn_mode, cudnnDataType_t dataType) { - if (!rnn_desc_) - CUDNN_RETURN_IF_ERROR(cudnnCreateRNNDescriptor(&rnn_desc_)); + if (!cudnn_rnn_desc_) + CUDNN_RETURN_IF_ERROR(cudnnCreateRNNDescriptor(&cudnn_rnn_desc_)); CUDNN_RETURN_IF_ERROR(cudnnSetRNNDescriptor(cudnnHandle, - rnn_desc_, + cudnn_rnn_desc_, gsl::narrow_cast(hidden_size), num_layers, cudnn_dropout_desc, @@ -54,11 +57,11 @@ class CudnnRNN { } operator cudnnRNNDescriptor_t() const { - return rnn_desc_; + return cudnn_rnn_desc_; } 
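The zero-length-sequence handling in cudnn_rnn_base.cc above bumps such sequences to length 1 before calling cuDNN and records their batch indices so the corresponding outputs can be zeroed afterwards. A host-side sketch of that index-cache construction, assuming the same offset convention for the reverse direction; the function name is illustrative:

#include <cstdint>
#include <vector>

// Bump zero-length sequences to 1 (cuDNN requires length >= 1) and remember their
// batch indices; after the forward pass the cached indices are used to zero the
// matching rows of Y / Y_h / Y_c, mirroring the MaskZeroSequences call above.
std::vector<int32_t> BuildZeroSeqIndexCacheSketch(std::vector<int32_t>& seq_lens,
                                                  int batch_size, int num_directions) {
  std::vector<int32_t> zero_indices;
  for (int i = 0; i < batch_size; ++i) {
    if (seq_lens[i] == 0) {
      seq_lens[i] = 1;
      zero_indices.push_back(i);
    }
  }
  // For bidirectional runs the reverse-direction entries sit at offset batch_size,
  // following the offset convention used in the hunk above.
  if (!zero_indices.empty() && num_directions > 1) {
    const size_t forward_count = zero_indices.size();
    for (size_t i = 0; i < forward_count; ++i) {
      zero_indices.push_back(static_cast<int32_t>(batch_size) + zero_indices[i]);
    }
  }
  return zero_indices;
}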
private: - cudnnRNNDescriptor_t rnn_desc_; + cudnnRNNDescriptor_t cudnn_rnn_desc_; }; template @@ -68,23 +71,40 @@ class CudnnRnnBase : public CudaKernel { public: CudnnRnnBase(const OpKernelInfo& info) : CudaKernel{info} { reverse_ = false; - ORT_ENFORCE(info.GetAttr("direction", &direction_).IsOK()); - num_directions_ = direction_ == "bidirectional" ? 2 : 1; - ORT_ENFORCE(allowed_directions.find(direction_) != allowed_directions.end()); + std::string direction = "forward"; + direction = info.GetAttrOrDefault("direction", "forward"); + cudnn_direction_mode_ = CUDNN_UNIDIRECTIONAL; + if (direction == "bidirectional") { + cudnn_direction_mode_ = CUDNN_BIDIRECTIONAL; + } else if (direction == "forward") { + cudnn_direction_mode_ = CUDNN_UNIDIRECTIONAL; + } else if (direction == "reverse") { + cudnn_direction_mode_ = CUDNN_UNIDIRECTIONAL; + // need to reverse data + reverse_ = true; + } + + num_directions_ = cudnn_direction_mode_ == CUDNN_BIDIRECTIONAL ? 2 : 1; + ORT_ENFORCE(allowed_directions.find(direction) != allowed_directions.end()); ORT_ENFORCE(info.GetAttr("hidden_size", &hidden_size_).IsOK() && hidden_size_ > 0); rnn_mode_ = CUDNN_LSTM; - num_layers_ = 1; weight_cached_ = false; w_data_cache_ = nullptr; + + size_t state_size; + cudnn_dropout_desc_.CreateDescriptorIfNeeded(); + cudnn_dropout_desc_.GetCudnnDropoutStatesSize(CudnnHandle(), state_size); + state_buffer_ = GetScratchBuffer(state_size); + cudnn_dropout_desc_.Set(CudnnHandle(), state_buffer_.get(), state_size); } - Status SetCudnnRnnDesc(); - Status CacheCudnnRnnWeights(const OpKernelInfo& info); Status ComputeInternal(OpKernelContext* ctx) const override; + void SetRNNMode(cudnnRNNMode_t rnn_mode) { rnn_mode_ = rnn_mode; } + private: Status SetCudnnRnnWeightBias(const cudnnHandle_t cudnn_handle, const cudnnRNNDescriptor_t rnn_desc, @@ -97,7 +117,8 @@ class CudnnRnnBase : public CudaKernel { Status ReorganizeWeights(const Tensor* W, const Tensor* R, const Tensor* B, IAllocatorUniquePtr& target_w_data, - CudnnFilterDescriptor& target_w_desc) const; + CudnnFilterDescriptor& target_w_desc, + CudnnRNN& rnn_desc) const; void SetWeightBias(const cudnnHandle_t handle, const cudnnRNNDescriptor_t rnn_desc, @@ -111,27 +132,32 @@ class CudnnRnnBase : public CudaKernel { int& offset, bool is_matrix) const; + void SetZeroSequences(const int64_t zero_seq_index_cache_size, + const std::vector zero_seq_index_cache, + T* y_data, + T* y_h_data, + T* y_c_data) const; + protected: - int64_t num_directions_; - // required - int64_t hidden_size_; - cudnnRNNMode_t rnn_mode_; + // W_lin_layer_id_ & R_lin_layer_id_ are set in Constructor std::vector W_lin_layer_id_; std::vector R_lin_layer_id_; - CudnnRNN rnn_desc_; - bool reverse_; - int num_layers_; private: - // optional - std::string direction_; + cudnnDirectionMode_t cudnn_direction_mode_; + bool reverse_; + int64_t num_directions_; + // hidden_size_ from attribute + int64_t hidden_size_; + cudnnRNNMode_t rnn_mode_; + // w_desc_cache_ & w_data_cache_ are changed in Constructor if we can get the weights as constant input CudnnFilterDescriptor w_desc_cache_; - CudnnDropout cudnn_dropout_desc_; - CudnnFilterDescriptor filter_desc_; IAllocatorUniquePtr w_data_cache_; bool weight_cached_; + + // cudnn_dropout_desc_ is a cache, never to be changed IAllocatorUniquePtr state_buffer_; - size_t state_size_; + CudnnDropout cudnn_dropout_desc_; enum Output_Index { Y = 0, diff --git a/onnxruntime/core/providers/cuda/rnn/gru.h b/onnxruntime/core/providers/cuda/rnn/gru.h index 43a0ba4ab5878..ab9dabff5db36 
100644 --- a/onnxruntime/core/providers/cuda/rnn/gru.h +++ b/onnxruntime/core/providers/cuda/rnn/gru.h @@ -15,8 +15,7 @@ template class GRU final : public CudnnRnnBase { public: GRU(const OpKernelInfo& info) : CudnnRnnBase(info) { - CudnnRnnBase::rnn_mode_ = CUDNN_GRU; - CudnnRnnBase::SetCudnnRnnDesc(); + CudnnRnnBase::SetRNNMode(CUDNN_GRU); // ONNX W layout is Wzrh, WBzrh, mapping to RNNLinLayerMatrixParams the linLayerID is 1, 0, 2 CudnnRnnBase::W_lin_layer_id_.assign({1, 0, 2}); diff --git a/onnxruntime/core/providers/cuda/rnn/lstm.h b/onnxruntime/core/providers/cuda/rnn/lstm.h index 3ba719d61750d..3ed12cfa7fff9 100644 --- a/onnxruntime/core/providers/cuda/rnn/lstm.h +++ b/onnxruntime/core/providers/cuda/rnn/lstm.h @@ -13,8 +13,7 @@ class LSTM final : public CudnnRnnBase { public: LSTM(const OpKernelInfo& info) : CudnnRnnBase(info) { - CudnnRnnBase::rnn_mode_ = CUDNN_LSTM; - CudnnRnnBase::SetCudnnRnnDesc(); + CudnnRnnBase::SetRNNMode(CUDNN_LSTM); // ONNX W layout is W[iofc], WB[iofc], mapping to RNNLinLayerMatrixParams the linLayerID is 0, 3, 1, 2 CudnnRnnBase::W_lin_layer_id_.assign({0, 3, 1, 2}); diff --git a/onnxruntime/core/providers/cuda/rnn/rnn.h b/onnxruntime/core/providers/cuda/rnn/rnn.h index 246e8d1062df0..dbb0d2843fe11 100644 --- a/onnxruntime/core/providers/cuda/rnn/rnn.h +++ b/onnxruntime/core/providers/cuda/rnn/rnn.h @@ -20,11 +20,9 @@ class RNN final : public CudnnRnnBase { std::vector activations_; ORT_ENFORCE(info.GetAttrs("activations", activations_).IsOK()); if (activations_[0] == "Relu") - CudnnRnnBase::rnn_mode_ = CUDNN_RNN_RELU; + CudnnRnnBase::SetRNNMode(CUDNN_RNN_RELU); else if (activations_[0] == "Tanh") - CudnnRnnBase::rnn_mode_ = CUDNN_RNN_TANH; - - CudnnRnnBase::SetCudnnRnnDesc(); + CudnnRnnBase::SetRNNMode(CUDNN_RNN_TANH); // ONNX W mapping to RNNLinLayerMatrixParams the linLayerID is 0 CudnnRnnBase::W_lin_layer_id_.assign({0}); diff --git a/onnxruntime/core/providers/cuda/rnn/rnn_impl.cu b/onnxruntime/core/providers/cuda/rnn/rnn_impl.cu index ae210ae6818de..930c3a4ddd343 100644 --- a/onnxruntime/core/providers/cuda/rnn/rnn_impl.cu +++ b/onnxruntime/core/providers/cuda/rnn/rnn_impl.cu @@ -133,6 +133,48 @@ void RnnMaskImpl(const int32_t num_directions, div_dir_block, div_batch_block, y_output_data, y_h_output_data, (CUDA_LONG)N); } +template +__global__ void _MaskZeroSequences(const int32_t hidden_size, + T* y_output_data, + T* y_h_output_data, + T* y_c_output_data, + const int32_t* zeor_seq_index_cache, + const CUDA_LONG N) { + CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(id, N); + + int32_t zero_seq_offset = zeor_seq_index_cache[id] * hidden_size; + + if (y_output_data != nullptr) { + for (int i = 0; i < hidden_size; ++i) { + y_output_data[zero_seq_offset + i] = 0; + } + } + + if (y_h_output_data != nullptr) { + for (int i = 0; i < hidden_size; ++i) { + y_h_output_data[zero_seq_offset + i] = 0; + } + } + + if (y_c_output_data != nullptr) { + for (int i = 0; i < hidden_size; ++i) { + y_c_output_data[zero_seq_offset + i] = 0; + } + } +} + +template +void MaskZeroSequences(const int32_t hidden_size, + T* y_output_data, + T* y_h_output_data, + T* y_c_output_data, + const int32_t* zeor_seq_index_cache, + const size_t N) { + int blocksPerGrid = (int)(ceil(static_cast(N) / GridDim::maxThreadsPerBlock)); + _MaskZeroSequences<<>>( + hidden_size, y_output_data, y_h_output_data, y_c_output_data, zeor_seq_index_cache, (CUDA_LONG)N); +} + #define SPECIALIZED_RNN_IMPL(T) \ template void RnnMaskImpl(const int32_t num_directions, \ const int32_t seq_length, \ @@ -153,7 +195,13 @@ 
void RnnMaskImpl(const int32_t num_directions, const int32_t hidden_size,\ const T* data, \ T* reordered_data, \ - const size_t N); + const size_t N); \ +template void MaskZeroSequences(const int32_t hidden_size, \ + T* y_output_data, \ + T* y_h_output_data, \ + T* y_c_output_data, \ + const int32_t* zeor_seq_index_cache, \ + const size_t N); SPECIALIZED_RNN_IMPL(half) SPECIALIZED_RNN_IMPL(float) diff --git a/onnxruntime/core/providers/cuda/rnn/rnn_impl.h b/onnxruntime/core/providers/cuda/rnn/rnn_impl.h index d25d71aed3fb1..78ceabf23bf2e 100644 --- a/onnxruntime/core/providers/cuda/rnn/rnn_impl.h +++ b/onnxruntime/core/providers/cuda/rnn/rnn_impl.h @@ -34,5 +34,12 @@ void RnnMaskImpl(const int32_t num_directions, T* y_h_output_data, const size_t N); +template +void MaskZeroSequences(const int32_t hidden_size, + T* y_output_data, + T* y_h_output_data, + T* y_c_output_data, + const int32_t* zeor_seq_index_cache_async_buffer, + const size_t N); } // namespace cuda } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/tensor/compress.cc b/onnxruntime/core/providers/cuda/tensor/compress.cc index 4e33a421846b9..9e23ad6a5fc1a 100644 --- a/onnxruntime/core/providers/cuda/tensor/compress.cc +++ b/onnxruntime/core/providers/cuda/tensor/compress.cc @@ -13,7 +13,8 @@ ONNX_OPERATOR_KERNEL_EX( kOnnxDomain, 9, kCudaExecutionProvider, - KernelDefBuilder().TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes()), + KernelDefBuilder().TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes()) + .TypeConstraint("T1", DataTypeImpl::GetTensorType()), Compress); Status Compress::ComputeInternal(OpKernelContext* ctx) const { diff --git a/onnxruntime/core/providers/cuda/tensor/resize_impl.cu b/onnxruntime/core/providers/cuda/tensor/resize_impl.cu index f8df8a9689f02..55d7fcaf01f49 100644 --- a/onnxruntime/core/providers/cuda/tensor/resize_impl.cu +++ b/onnxruntime/core/providers/cuda/tensor/resize_impl.cu @@ -29,8 +29,13 @@ __global__ void _ResizeNearestKernel(const size_t rank, output_data[id] = input_data[input_index]; } +// The following method supports a 4-D input in 'Linear mode' +// that amounts to 'Bilinear' Upsampling/Resizing in the sense that it assumes +// the scale values for the outermost 2 dimensions are 1. 
+// This is the common use-case where the 4-D input (batched multi-channel images) +// is usually of shape [N, C, H, W] and the scales are [1.0, 1.0, height_scale, width_scale] template -__global__ void _ResizeBilinearKernel(const int64_t input_dim2, +__global__ void _ResizeBilinear4DInputKernel(const int64_t input_dim2, const int64_t* input_pitches, const fast_divmod* output_div_pitches, const float* scales, @@ -90,6 +95,62 @@ __global__ void _ResizeBilinearKernel(const int64_t input_dim2, x11 * static_cast(y_offset_0 * x_offset_0); } +// The following method supports a 2-D input in 'Linear mode' +template +__global__ void _ResizeBilinear2DInputKernel(const int64_t input_dim0, + const int64_t* input_pitches, + const fast_divmod* output_div_pitches, + const float* scales, + const T* input_data, + T* output_data, + const size_t N) { + CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(id, N); + CUDA_LONG input_index = 0; + + int mod; + int index_of_dim0, index_of_dim1; + output_div_pitches[0].divmod(id, index_of_dim0, mod); + index_of_dim1 = mod; + int index_of_input_dim0, index_of_input_dim1; + float x_offset_0, y_offset_0, x_offset_1, y_offset_1; + index_of_input_dim0 = static_cast(index_of_dim0 / scales[0]); + index_of_input_dim1 = static_cast(index_of_dim1 / scales[1]); + input_index = index_of_input_dim0 * input_pitches[0] + index_of_input_dim1; + + T x00 = input_data[input_index]; + T x10, x01, x11; + + bool end_of_dim0 = false, end_of_dim1 = false; + if (index_of_input_dim0 == (input_dim0 - 1)) { + // It's the end in dimension 0 + x01 = x00; + end_of_dim0 = true; + } else { + x01 = input_data[input_index + input_pitches[0]]; + } + + if (index_of_input_dim1 == (input_pitches[0] - 1)) { + // It's the end in dimension 1 + x10 = x00; + x11 = x01; + end_of_dim1 = true; + } else { + x10 = input_data[input_index + 1]; + x11 = end_of_dim0 ? x10 : input_data[input_index + input_pitches[0] + 1]; + } + + y_offset_0 = end_of_dim0 ? 0.5f : index_of_dim0 / scales[0] - index_of_input_dim0; + y_offset_1 = 1.0f - y_offset_0; + x_offset_0 = end_of_dim1 ? 
0.5f : index_of_dim1 / scales[1] - index_of_input_dim1; + x_offset_1 = 1.0f - x_offset_0; + + output_data[id] = + x00 * static_cast(y_offset_1 * x_offset_1) + + x01 * static_cast(y_offset_0 * x_offset_1) + + x10 * static_cast(y_offset_1 * x_offset_0) + + x11 * static_cast(y_offset_0 * x_offset_0); +} + template void ResizeImpl(const onnxruntime::UpsampleMode upsample_mode, const size_t rank, @@ -105,8 +166,12 @@ void ResizeImpl(const onnxruntime::UpsampleMode upsample_mode, _ResizeNearestKernel<<>>( rank, input_pitches, output_div_pitches, scales_vals, input_data, output_data, N); - } else if (onnxruntime::UpsampleMode::LINEAR == upsample_mode) { - _ResizeBilinearKernel<<>>( + } else if (onnxruntime::UpsampleMode::LINEAR == upsample_mode && rank == 4) { + _ResizeBilinear4DInputKernel<<>>( + input_dim2, input_pitches, output_div_pitches, scales_vals, + input_data, output_data, N); + } else if (onnxruntime::UpsampleMode::LINEAR == upsample_mode && rank == 2) { + _ResizeBilinear2DInputKernel<<>>( input_dim2, input_pitches, output_div_pitches, scales_vals, input_data, output_data, N); } diff --git a/onnxruntime/core/providers/cuda/tensor/tile.cc b/onnxruntime/core/providers/cuda/tensor/tile.cc index 390d9139de58d..854c784c8a851 100644 --- a/onnxruntime/core/providers/cuda/tensor/tile.cc +++ b/onnxruntime/core/providers/cuda/tensor/tile.cc @@ -17,7 +17,8 @@ namespace cuda { kCudaExecutionProvider, \ KernelDefBuilder() \ .InputMemoryType(1) \ - .TypeConstraint("T", DataTypeImpl::GetTensorType()), \ + .TypeConstraint("T", DataTypeImpl::GetTensorType()) \ + .TypeConstraint("T1", DataTypeImpl::GetTensorType()), \ Tile); template diff --git a/onnxruntime/core/providers/cuda/tensor/upsample.cc b/onnxruntime/core/providers/cuda/tensor/upsample.cc index 88248983d70ae..3a9eb36c22f41 100644 --- a/onnxruntime/core/providers/cuda/tensor/upsample.cc +++ b/onnxruntime/core/providers/cuda/tensor/upsample.cc @@ -38,10 +38,21 @@ Status Upsample::BaseCompute(OpKernelContext* context, const std::vector& X_dims = X->Shape().GetDims(); auto rank = X_dims.size(); if (rank == 0) - return Status(ONNXRUNTIME, INVALID_ARGUMENT, "Upsample: input tensor cannot be scalar."); + return Status(ONNXRUNTIME, INVALID_ARGUMENT, + is_resize ? "Resize: input tensor cannot be scalar." : "Upsample: input tensor cannot be scalar."); if (rank != scales.size()) - return Status(ONNXRUNTIME, INVALID_ARGUMENT, "Upsample: input tensor's dimension does not match the scales."); + return Status(ONNXRUNTIME, INVALID_ARGUMENT, + is_resize ? "Resize: input tensor's dimension does not match the scales." : + "Upsample: input tensor's dimension does not match the scales."); + + if (UpsampleMode::LINEAR == mode_ && rank != 4 && rank != 2) { + std::ostringstream oss; + oss << "'Linear' mode only support 2-D inputs ('Bilinear') or 4-D inputs " + "with the corresponding outermost 2 scale values being 1 in the "; + oss << (is_resize ? 
"Resize operator" : "Upsample operator"); + return Status(ONNXRUNTIME, FAIL, oss.str()); + } std::vector Y_dims; for (std::size_t i = 0; i < rank; i++) { @@ -69,21 +80,12 @@ Status Upsample::BaseCompute(OpKernelContext* context, const std::vectorShape().Size(); - if (UpsampleMode::LINEAR == mode_) { - if (rank != 4) - if (is_resize) { - return Status(ONNXRUNTIME, FAIL, "Resize: linear mode only supports 4-D tensor with NCHW layout"); - } else { - return Status(ONNXRUNTIME, FAIL, "Upsample: linear mode only supports 4-D tensor with NCHW layout"); - } - } - if (is_resize) { CudaAsyncBuffer scales_vals(this, device_id, scales); scales_vals.CopyToGpu(); ResizeImpl(mode_, rank, - (UpsampleMode::LINEAR == mode_) ? X_dims[2] : 0, + (UpsampleMode::LINEAR == mode_) ? (rank == 2 ? X_dims[0] : X_dims[2]) : 0, input_strides.GpuPtr(), output_div_pitches.GpuPtr(), scales_vals.GpuPtr(), @@ -101,7 +103,7 @@ Status Upsample::BaseCompute(OpKernelContext* context, const std::vector -__global__ void _UpampleBilinearKernel(const int64_t input_dim2, +__global__ void _UpampleBilinear4DInputKernel(const int64_t input_dim2, const int64_t* input_pitches, const fast_divmod* output_div_pitches, const fast_divmod* scales_div, @@ -90,6 +95,59 @@ __global__ void _UpampleBilinearKernel(const int64_t input_dim2, output_data[id] = y0 + static_cast(x_offset_T * (y1 - y0) / scales_div3_T); } +// The following method supports a 2-D input in 'Linear mode' +template +__global__ void _UpampleBilinear2DInputKernel(const int64_t input_dim0, + const int64_t* input_pitches, + const fast_divmod* output_div_pitches, + const fast_divmod* scales_div, + const T* input_data, + T* output_data, + const size_t N) { + CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(id, N); + CUDA_LONG input_index = 0; + + int mod; + int index_of_dim0, index_of_dim1; + output_div_pitches[0].divmod(id, index_of_dim0, mod); + index_of_dim1 = mod; + int index_of_input_dim0, index_of_input_dim1, x_offset, y_offset; + scales_div[0].divmod(index_of_dim0, index_of_input_dim0, y_offset); + scales_div[1].divmod(index_of_dim1, index_of_input_dim1, x_offset); + + input_index = index_of_input_dim0 * input_pitches[0] + index_of_input_dim1; + + T x00 = input_data[input_index]; + T x10, x01, x11; + + bool end_of_dim0 = false; + if (index_of_input_dim0 == (input_dim0 - 1)) { + // It's the end in dimension 0 + x01 = x00; + end_of_dim0 = true; + } else { + x01 = input_data[input_index + input_pitches[0]]; + } + + if (index_of_input_dim1 == (input_pitches[0] - 1)) { + // It's the end in dimension 1 + x10 = x00; + x11 = x01; + } else { + x10 = input_data[input_index + 1]; + x11 = end_of_dim0 ? 
x10 : input_data[input_index + input_pitches[0] + 1]; + } + + T y_offset_T = static_cast(y_offset); + T x_offset_T = static_cast(x_offset); + T scales_div0_T = static_cast(scales_div[0].d_); + T scales_div1_T = static_cast(scales_div[1].d_); + T y0 = x00 + static_cast(y_offset_T * (x01 - x00) / scales_div0_T); + T y1 = x10 + static_cast(y_offset_T * (x11 - x10) / scales_div0_T); + + output_data[id] = y0 + static_cast(x_offset_T * (y1 - y0) / scales_div1_T); +} + template void UpampleImpl(const onnxruntime::UpsampleMode upsample_mode, const size_t rank, @@ -105,8 +163,12 @@ void UpampleImpl(const onnxruntime::UpsampleMode upsample_mode, _UpampleNearestKernel<<>>( rank, input_pitches, output_div_pitches, scales_div, input_data, output_data, N); - } else if (onnxruntime::UpsampleMode::LINEAR == upsample_mode) { - _UpampleBilinearKernel<<>>( + } else if (onnxruntime::UpsampleMode::LINEAR == upsample_mode && rank == 4) { + _UpampleBilinear4DInputKernel<<>>( + input_dim2, input_pitches, output_div_pitches, scales_div, + input_data, output_data, N); + } else if (onnxruntime::UpsampleMode::LINEAR == upsample_mode && rank == 2) { + _UpampleBilinear2DInputKernel<<>>( input_dim2, input_pitches, output_div_pitches, scales_div, input_data, output_data, N); } diff --git a/onnxruntime/core/providers/mkldnn/mkldnn_execution_provider.cc b/onnxruntime/core/providers/mkldnn/mkldnn_execution_provider.cc index 93cb36116f964..a2908888c5b49 100644 --- a/onnxruntime/core/providers/mkldnn/mkldnn_execution_provider.cc +++ b/onnxruntime/core/providers/mkldnn/mkldnn_execution_provider.cc @@ -101,7 +101,7 @@ bool MKLDNNExecutionProvider::UseSubgraph(const onnxruntime::GraphViewer& graph_ index++; node = graph_viewer.GetNode(index); } - if (node->InputDefs()[0]->Type() != nullptr) + if (!node->InputDefs().empty() && node->InputDefs()[0]->Type() != nullptr) FP16_graph = node->InputDefs()[0]->Type()->find("16") != std::string::npos; } @@ -357,8 +357,8 @@ void MKLDNNExecutionProvider::CreateMetaDef(const onnxruntime::GraphViewer& grap std::vector>& result) const { std::string graph_fused_nodes; std::string node_list; - std::string subgraph_id = std::to_string(sub_var.subgraph_index); - sub_var.subgraph_index++; + std::string subgraph_id = std::to_string(subgraph_index_); + subgraph_index_++; // This is a list of initializers that subgraph considers as constants. // Example weights, reshape shape etc. 
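Stepping back to the 2-D 'Linear' kernels added above for Resize and Upsample: both gather the four neighbouring input samples, clamp at the last row and column, and blend them with the fractional offsets. The end-of-dimension-1 test compares against input_pitches[0], which works because, for a 2-D row-major tensor, the pitch of dimension 0 equals the extent of dimension 1. Below is a plain CPU reference of the same bilinear scheme, offered only as a sketch; ResizeBilinear2D is a hypothetical name and the float, row-major layout is an assumption.

// Hedged CPU reference for the 2-D bilinear resize implemented by the new kernels.
#include <cstdint>
#include <vector>

std::vector<float> ResizeBilinear2D(const std::vector<float>& in,
                                    int64_t in_h, int64_t in_w,
                                    float scale_h, float scale_w) {
  const int64_t out_h = static_cast<int64_t>(in_h * scale_h);
  const int64_t out_w = static_cast<int64_t>(in_w * scale_w);
  std::vector<float> out(out_h * out_w);
  for (int64_t y = 0; y < out_h; ++y) {
    for (int64_t x = 0; x < out_w; ++x) {
      const int64_t y0 = static_cast<int64_t>(y / scale_h);  // top-left input sample
      const int64_t x0 = static_cast<int64_t>(x / scale_w);
      const bool last_row = (y0 == in_h - 1);
      const bool last_col = (x0 == in_w - 1);
      const int64_t y1 = last_row ? y0 : y0 + 1;              // clamp at the border
      const int64_t x1 = last_col ? x0 : x0 + 1;
      const float dy = last_row ? 0.5f : y / scale_h - y0;    // fractional offsets
      const float dx = last_col ? 0.5f : x / scale_w - x0;
      const float p00 = in[y0 * in_w + x0], p01 = in[y1 * in_w + x0];
      const float p10 = in[y0 * in_w + x1], p11 = in[y1 * in_w + x1];
      out[y * out_w + x] = p00 * (1 - dy) * (1 - dx) + p01 * dy * (1 - dx) +
                           p10 * (1 - dy) * dx + p11 * dy * dx;
    }
  }
  return out;
}

The 0.5f offset used at a clamped border mirrors the kernels and is harmless: both neighbouring samples are identical there, so any blend weight yields the border value.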
@@ -378,7 +378,7 @@ void MKLDNNExecutionProvider::CreateMetaDef(const onnxruntime::GraphViewer& grap auto meta_def = std::make_unique<::onnxruntime::IndexedSubGraph::MetaDef>(); meta_def->attributes["initializers"] = initializers; - meta_def->name = "MkldnnCustomOp" + std::to_string(sub_var.subgraph_index); + meta_def->name = "MkldnnCustomOp" + std::to_string(subgraph_index_); meta_def->domain = kMSDomain; meta_def->since_version = 1; meta_def->status = ONNX_NAMESPACE::EXPERIMENTAL; diff --git a/onnxruntime/core/providers/mkldnn/mkldnn_execution_provider.h b/onnxruntime/core/providers/mkldnn/mkldnn_execution_provider.h index 2869698568bde..a57f290689382 100644 --- a/onnxruntime/core/providers/mkldnn/mkldnn_execution_provider.h +++ b/onnxruntime/core/providers/mkldnn/mkldnn_execution_provider.h @@ -147,6 +147,8 @@ class MKLDNNExecutionProvider : public IExecutionProvider { } private: + mutable int subgraph_index_ = 0; + // supported MklDnn Operators std::set mkldnn_ops_ = {"Conv", "BatchNormalization", "Relu", "Sum", "AveragePool", "GlobalMaxPool", "GlobalAveragePool", "MaxPool", "LRN"}; diff --git a/onnxruntime/core/providers/mkldnn/mkldnn_provider_factory.cc b/onnxruntime/core/providers/mkldnn/mkldnn_provider_factory.cc index c94060d5b2450..2cc7e112a29e2 100644 --- a/onnxruntime/core/providers/mkldnn/mkldnn_provider_factory.cc +++ b/onnxruntime/core/providers/mkldnn/mkldnn_provider_factory.cc @@ -27,6 +27,7 @@ std::unique_ptr MkldnnProviderFactory::CreateProvider() { std::shared_ptr CreateExecutionProviderFactory_Mkldnn(int device_id) { return std::make_shared(device_id); + //TODO: This is apparently a bug. The consructor parameter is create-arena-flag, not the device-id } } // namespace onnxruntime diff --git a/onnxruntime/core/providers/mkldnn/subgraph/subgraph.h b/onnxruntime/core/providers/mkldnn/subgraph/subgraph.h index 6e3f967dab65d..b63692bce04ec 100644 --- a/onnxruntime/core/providers/mkldnn/subgraph/subgraph.h +++ b/onnxruntime/core/providers/mkldnn/subgraph/subgraph.h @@ -48,12 +48,8 @@ struct Subgraph { std::vector outputs; std::vector outputs_as_input_other_node; std::vector subgraph_node_indexes; - int subgraph_index = 0; - SubgraphVariables() { - subgraph_index = 0; - } - void Reset() { + void Reset() { subgraph_node_indexes.clear(); inputs.clear(); outputs.clear(); diff --git a/onnxruntime/core/providers/ngraph/ngraph_custom_op.cc b/onnxruntime/core/providers/ngraph/ngraph_custom_op.cc index 5ac6354bbb11b..326e878cbcd95 100644 --- a/onnxruntime/core/providers/ngraph/ngraph_custom_op.cc +++ b/onnxruntime/core/providers/ngraph/ngraph_custom_op.cc @@ -25,20 +25,23 @@ namespace onnxruntime { namespace ngraph_ep { +#define NGRAPH_EP_LRU_CACHE_DEFAULT_SIZE 500 + static bool check_ngraph_dump_ops() { #ifdef _WIN32 size_t env_name_len = 0; char* env_name = nullptr; - return (_dupenv_s(&env_name, &env_name_len, "ONNXRUNTIME_NGRAPH_DUMP_OPS") == 0); + return (_dupenv_s(&env_name, &env_name_len, "ONNXRUNTIME_NGRAPH_DUMP_OPS") == 0 && env_name != nullptr); #else return (std::getenv("ONNXRUNTIME_NGRAPH_DUMP_OPS") != nullptr); #endif } -NGRAPHCustomOp::NGRAPHCustomOp(const ComputeContext* context, const ONNX_NAMESPACE::ModelProto& model_proto, - const std::shared_ptr& ng_backend) - : ng_backend_{ng_backend}, - model_proto_{model_proto} { +NGRAPHCustomOp::NGRAPHCustomOp(const ComputeContext* context, + const ONNX_NAMESPACE::ModelProto& model_proto, + const std::shared_ptr& ng_backend) : + ng_backend_{ng_backend}, model_proto_{model_proto} +{ allocate_func_ = context->allocate_func; 
release_func_ = context->release_func; allocator_ = context->allocator_handle; @@ -59,7 +62,6 @@ NGRAPHCustomOp::~NGRAPHCustomOp() { //This method gets called in critical path of execution: Optimize void NGRAPHCustomOp::Initialize(const OrtCustomOpApi* api, OrtKernelContext* context) const { Ort::CustomOpApi ort{*api}; - LOGS_DEFAULT(INFO) << "nGraph compiling customOp: " << name_; size_t num_inputs = ort.KernelContext_GetInputCount(context); @@ -80,7 +82,45 @@ void NGRAPHCustomOp::Initialize(const OrtCustomOpApi* api, OrtKernelContext* con uniq_input_shape.append(reinterpret_cast(tensor_shape.data()), ndim * sizeof(int64_t)); } - auto it = ng_exe_map_.insert({uniq_input_shape, nullptr}); //TODO: Limit the size of map with configurable size. + // Get cache size from environment + std::string tempSize; + #ifdef _WIN32 + char *buf{nullptr}; + size_t bufSize = 0; + if (!_dupenv_s(&buf, &bufSize, "ONNXRUNTIME_NGRAPH_LRU_CACHE_SIZE") && buf) { + tempSize = buf; + free(buf); + } + #else + if (std::getenv("ONNXRUNTIME_NGRAPH_LRU_CACHE_SIZE")) { + tempSize = std::getenv("ONNXRUNTIME_NGRAPH_LRU_CACHE_SIZE"); + } + #endif + size_t cacheSize = tempSize.empty() ? NGRAPH_EP_LRU_CACHE_DEFAULT_SIZE : std::stoi(tempSize); + + // Not in cache + if (ng_exe_map_.find(uniq_input_shape) == ng_exe_map_.end()) { + // Check if full + if (keyCache.size() == cacheSize) { + // Delete least recently used element + std::string last = keyCache.back(); + + // Pop the last elmeent + keyCache.pop_back(); + + // Erase the last element from cache + ng_exe_map_.erase(ng_exe_map_.find(last)); + } + } + + // Found in cache + else { + keyCache.remove(uniq_input_shape); + } + + // update reference + keyCache.push_front(uniq_input_shape); + auto it = ng_exe_map_.insert({uniq_input_shape, nullptr}); //ng_exe with current shape already exists if (!it.second) { @@ -88,6 +128,9 @@ void NGRAPHCustomOp::Initialize(const OrtCustomOpApi* api, OrtKernelContext* con return; } else { auto graph_proto = model_proto_.mutable_graph(); + + LOGS_DEFAULT(INFO) << "[NGRAPHCustomOp] Compiling customOp: " << name_; + // Clear previous shapes if any and set new input shapes for (size_t i = 0; i < num_inputs; i++) { auto g_in_shape = graph_proto->mutable_input((int)i)->mutable_type()->mutable_tensor_type()->mutable_shape(); @@ -108,12 +151,12 @@ void NGRAPHCustomOp::Initialize(const OrtCustomOpApi* api, OrtKernelContext* con try { ng_function = ngraph::onnx_import::import_onnx_model(model_stream); } catch (const std::exception& exp) { - LOGS_DEFAULT(FATAL) << "[" << name_ << "] " - << "Exception while converting onnx to nGraph: " << std::string(exp.what()); + LOGS_DEFAULT(FATAL) << "[NGRAPHCustomOp] " << " - " << name_ << " - " + << "Exception while importing model to nGraph: " << std::string(exp.what()); throw; } catch (...) { - LOGS_DEFAULT(FATAL) << "[" << name_ << "] " - << "Unknown exception while converting onnx to nGraph"; + LOGS_DEFAULT(FATAL) << "[NGRAPHCustomOp] " << " - " << name_ << " - " + << "Unknown exception while importing model to nGraph"; throw; } @@ -125,9 +168,10 @@ void NGRAPHCustomOp::Initialize(const OrtCustomOpApi* api, OrtKernelContext* con try { ng_curr_exe_ = ng_backend_->compile(ng_function); } catch (const std::exception& exp) { - LOGS_DEFAULT(FATAL) << "Exception while compiling nGraph Op: " << name_ << std::string(exp.what()); + LOGS_DEFAULT(FATAL) << "[NGRAPHCustomOp] " << " - " << name_ << " - " + << "Exception while compiling ngraph::Function: " << std::string(exp.what()); } catch (...) 
{ - LOGS_DEFAULT(FATAL) << "Unknown exception while compiling nGraph Op: " << name_; + LOGS_DEFAULT(FATAL) << "[NGRAPHCustomOp] " << " - " << name_ << " - " << "Unknown exception while compiling ngraph::Function"; } it.first->second = ng_curr_exe_; } @@ -137,11 +181,11 @@ void NGRAPHCustomOp::Initialize(const OrtCustomOpApi* api, OrtKernelContext* con Status NGRAPHCustomOp::Compute(const OrtCustomOpApi* api, OrtKernelContext* context) const { Ort::CustomOpApi ort{*api}; - //TODO: Minimize locked region - std::lock_guard lock(compute_lock_); - // Initialize nGraph function if it is not already initialized. - Initialize(api, context); + { + std::lock_guard lock(compute_lock_); + Initialize(api, context); + } ORT_ENFORCE(ng_curr_exe_ != nullptr); @@ -154,12 +198,13 @@ Status NGRAPHCustomOp::Compute(const OrtCustomOpApi* api, OrtKernelContext* cont for (const auto& ng_param : ng_curr_exe_->get_parameters()) { const OrtValue* input_tensor = ort.KernelContext_GetInput(context, input_index++); void* input_data = const_cast(ort.GetTensorData(input_tensor)); + std::lock_guard lock(compute_lock_); ng_inputs.emplace_back(ng_backend_->create_tensor(ng_param->get_output_element_type(0), ng_param->get_output_shape(0), input_data)); } } catch (const std::exception& exp) { - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Exception while copying input data to nGraph: " + std::string(exp.what())); + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, name_ + ": Exception while copying input data to nGraph: " + std::string(exp.what())); } catch (...) { - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Unknown exception while copying input data to nGraph"); + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, name_ + ": Unknown exception while copying input data to nGraph"); } // Initialize output tensors @@ -173,22 +218,24 @@ Status NGRAPHCustomOp::Compute(const OrtCustomOpApi* api, OrtKernelContext* cont std::vector ort_shape{shape.begin(), shape.end()}; OrtValue* output_tensor = ort.KernelContext_GetOutput(context, output_index++, ort_shape.data(), ort_shape.size()); void* output_data = ort.GetTensorMutableData(output_tensor); + std::lock_guard lock(compute_lock_); ng_outputs.emplace_back(ng_backend_->create_tensor(dtype, shape, output_data)); } } catch (const std::exception& exp) { - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Exception while creating nGraph output Tensor: " + std::string(exp.what())); + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, name_ + ": Exception while creating nGraph output Tensor: " + std::string(exp.what())); } catch (...) { - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Unknown exception while creating nGraph output Tensor"); + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, name_ + ": Unknown exception while creating nGraph output Tensor"); } // Run the graph through nGraph. try { + std::lock_guard lock(compute_lock_); if (!ng_curr_exe_->call(ng_outputs, ng_inputs)) - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Error while executing nGraph computation"); + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, name_ + ": Error while executing nGraph computation"); } catch (const std::exception& exp) { - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Exception while executing nGraph computation: " + std::string(exp.what())); + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, name_ + ": Exception while executing nGraph computation: " + std::string(exp.what())); } catch (...) 
{ - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Unknown exception while executing nGraph computation"); + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, name_ + ": Unknown exception while executing nGraph computation"); } return Status::OK(); diff --git a/onnxruntime/core/providers/ngraph/ngraph_custom_op.h b/onnxruntime/core/providers/ngraph/ngraph_custom_op.h index 6661fdb378e56..ad9955872d7fb 100644 --- a/onnxruntime/core/providers/ngraph/ngraph_custom_op.h +++ b/onnxruntime/core/providers/ngraph/ngraph_custom_op.h @@ -25,7 +25,9 @@ namespace ngraph_ep { class NGRAPHCustomOp { public: - NGRAPHCustomOp(const ComputeContext* context, const ONNX_NAMESPACE::ModelProto& model_proto, const std::shared_ptr& ng_backend); + NGRAPHCustomOp(const ComputeContext* context, + const ONNX_NAMESPACE::ModelProto& model_proto, + const std::shared_ptr& ng_backend); Status Compute(const OrtCustomOpApi* api, OrtKernelContext* context) const; @@ -54,7 +56,8 @@ class NGRAPHCustomOp { key = [3,1,2,3,2,4,5] */ mutable std::unordered_map> ng_exe_map_; - + mutable std::list keyCache; + mutable std::mutex compute_lock_; mutable ONNX_NAMESPACE::ModelProto model_proto_; diff --git a/onnxruntime/core/providers/ngraph/ngraph_execution_provider.cc b/onnxruntime/core/providers/ngraph/ngraph_execution_provider.cc index 459deae2c81f9..60fad0071ccf6 100644 --- a/onnxruntime/core/providers/ngraph/ngraph_execution_provider.cc +++ b/onnxruntime/core/providers/ngraph/ngraph_execution_provider.cc @@ -33,19 +33,36 @@ constexpr const char* NGRAPH = "nGraph"; NGRAPHExecutionProvider::NGRAPHExecutionProvider(const NGRAPHExecutionProviderInfo& info) : IExecutionProvider{onnxruntime::kNGraphExecutionProvider} { - DeviceAllocatorRegistrationInfo default_allocator_info({OrtMemTypeDefault, - [](int) { return std::make_unique(std::make_unique(NGRAPH, OrtAllocatorType::OrtDeviceAllocator)); }, - std::numeric_limits::max()}); + + ORT_ENFORCE(info.ng_backend_type == "CPU", "nGraph Execution Provider for onnxruntime currently is only supported for CPU backend."); + + auto default_allocator_factory = [](int) { + auto allocator_info = std::make_unique(NGRAPH, OrtAllocatorType::OrtDeviceAllocator); + return std::make_unique(std::move(allocator_info)); + }; + + DeviceAllocatorRegistrationInfo default_allocator_info{ + OrtMemTypeDefault, + std::move(default_allocator_factory), + std::numeric_limits::max() + }; InsertAllocator(CreateAllocator(default_allocator_info)); - DeviceAllocatorRegistrationInfo cpu_allocator_info({OrtMemTypeCPUOutput, - [](int) { return std::make_unique(std::make_unique(NGRAPH, OrtAllocatorType::OrtDeviceAllocator, OrtDevice(), 0, OrtMemTypeCPUOutput)); }, - std::numeric_limits::max()}); - InsertAllocator(CreateAllocator(cpu_allocator_info)); + auto cpu_allocator_factory = [](int) { + auto allocator_info = std::make_unique( + NGRAPH, OrtAllocatorType::OrtDeviceAllocator, OrtDevice(), 0, OrtMemTypeCPUOutput); + return std::make_unique(std::move(allocator_info)); + }; - ORT_ENFORCE(info.ng_backend_type == "CPU", "nGraph Execution Provider for onnxruntime currently is only supported for CPU backend."); + DeviceAllocatorRegistrationInfo cpu_allocator_info{ + OrtMemTypeCPUOutput, + std::move(cpu_allocator_factory), + std::numeric_limits::max() + }; + + InsertAllocator(CreateAllocator(cpu_allocator_info)); try { ng_backend_ = ngraph::runtime::Backend::create(info.ng_backend_type); @@ -57,25 +74,6 @@ NGRAPHExecutionProvider::NGRAPHExecutionProvider(const NGRAPHExecutionProviderIn } } -/** - * Checks if a tensor represented by srcLocation 
can be copied into the dstLocation tensor - * @param src_location result of Location().name call on the source tensor - * @param dst_location result of Location().name call on the destination tensor - * @return true if src and dest locations combination allows copying - */ -bool TensorCopyPossible(const std::string& src_location, const std::string& dst_location) { - // contains allowed combinations of source and destination locations for tensors copying purposes - // the first element of a pair denotes a source, the second - destination - static const std::map allowed_copy_directions = { - {NGRAPH, CPU}, {NGRAPH, NGRAPH}, {CPU, NGRAPH}}; - - // copying of tensors is allowed only if the params match any of the allowed combinations - return std::any_of(allowed_copy_directions.begin(), - allowed_copy_directions.end(), [&](const auto& copy_direction) { - return src_location == copy_direction.first && dst_location == copy_direction.second; - }); -} - // Returns true only if op is in a mode that is not currently supported static bool IsUnsupportedOpMode(const Node* node, const onnxruntime::GraphViewer& graph_viewer) { const auto& optype = node->OpType(); @@ -131,11 +129,6 @@ static bool IsUnsupportedOpMode(const Node* node, const onnxruntime::GraphViewer return true; } } - } else if (optype == "Cast") { - //support of casting to bool in nGraph is in progress - const auto& attributes = node->GetAttributes(); - const auto to_attr = attributes.find("to"); - return to_attr->second.i() == ONNX_NAMESPACE::TensorProto::BOOL; } else if (optype == "Slice") { //Slice in opset 10 is currently not supported. //unsupported inputs: starts, ends, axes, steps @@ -164,6 +157,63 @@ static bool IsUnsupportedOpMode(const Node* node, const onnxruntime::GraphViewer if (ceil_attr != attributes.end() && ceil_attr->second.i() != 0) { return true; } + } else if (optype == "Split") { + const auto& attributes = node->GetAttributes(); + const auto split_attr = attributes.find("split"); + + if (split_attr != attributes.end()) { + // split implementation contains a bug that doesn't throw for incorrect split values + // disabling temporarily until it's fixed in the next release of nGraph + const auto splits = split_attr->second.ints(); + return std::any_of(std::begin(splits), std::end(splits), + [](const auto split) { return split <= 0; }); + } + } else if (optype == "QLinearMatMul") { + const auto& a_zero_point = node->InputDefs()[2]; + const auto& b_zero_point = node->InputDefs()[5]; + const auto& y_zero_point = node->InputDefs()[7]; + + bool non_const_zero_point = false; + + // check if any of the zero points is NOT in the initializers list + non_const_zero_point |= initializers.find(a_zero_point->Name()) == initializers.end(); + non_const_zero_point |= initializers.find(b_zero_point->Name()) == initializers.end(); + non_const_zero_point |= initializers.find(y_zero_point->Name()) == initializers.end(); + + // QLinearMatMul is not supported if any of the zero points is a dynamic input + return non_const_zero_point; + } else if (optype == "MatMulInteger") { + // all MatMulInteger zero points need to be constants + const auto inputs = node->InputDefs(); + if (inputs.size() == 3) { + const auto& a_zero_point = node->InputDefs()[2]; + + // not found in initializers -> not const + return initializers.find(a_zero_point->Name()) == initializers.end(); + } else if (inputs.size() == 4) { + const auto& a_zero_point = node->InputDefs()[2]; + const auto& b_zero_point = node->InputDefs()[3]; + + // not found in initializers -> not const 
+ return initializers.find(a_zero_point->Name()) == initializers.end() || + initializers.find(b_zero_point->Name()) == initializers.end(); + } // else -> azp & bzp are 0 by default according to ONNX spec + } else if (optype == "ConvInteger") { + // all ConvInteger zero points need to be constants + const auto inputs = node->InputDefs(); + if (inputs.size() == 3) { + const auto& x_zero_point = node->InputDefs()[2]; + + // not found in initializers -> not const + return initializers.find(x_zero_point->Name()) == initializers.end(); + } else if (inputs.size() == 4) { + const auto& x_zero_point = node->InputDefs()[2]; + const auto& w_zero_point = node->InputDefs()[3]; + + // not found in initializers -> not const + return initializers.find(x_zero_point->Name()) == initializers.end() || + initializers.find(w_zero_point->Name()) == initializers.end(); + } // else -> xzp & wzp are 0 by default according to ONNX spec } //Op doesn't fall into known any of unsupported modes. @@ -237,21 +287,10 @@ static void AppendClusterToSubGraph(const std::vector& nodes, const onnxruntime::GraphViewer& graph_viewer, const std::vector& inputs, const std::vector& outputs, - const std::unordered_set& ng_required_initializers, std::vector>& result) { static size_t op_counter = 0; - // Create ng_required_initializers attribute of NGraphCustomOp - ONNX_NAMESPACE::AttributeProto initializers; - initializers.set_name("initializers"); - initializers.set_type(ONNX_NAMESPACE::AttributeProto_AttributeType::AttributeProto_AttributeType_TENSORS); - for (const auto& init : ng_required_initializers) { - auto tensor = initializers.add_tensors(); - *tensor = *(graph_viewer.GetAllInitializedTensors().at(init)); - } - auto meta_def = std::make_unique(); - meta_def->attributes["initializers"] = initializers; meta_def->name = "NGRAPHCustomOp_" + std::to_string(++op_counter); meta_def->domain = kNGraphDomain; meta_def->since_version = 1; @@ -259,6 +298,13 @@ static void AppendClusterToSubGraph(const std::vector& nodes, meta_def->inputs = inputs; meta_def->outputs = outputs; + //store the name of the graph this node belongs to - used to retrieve graph initializers from the cache + ONNX_NAMESPACE::AttributeProto graph_name; + graph_name.set_name("graph_name"); + graph_name.set_type(ONNX_NAMESPACE::AttributeProto_AttributeType::AttributeProto_AttributeType_STRING); + graph_name.set_s(graph_viewer.Name()); + meta_def->attributes["graph_name"] = graph_name; + std::unique_ptr sub_graph = std::make_unique(); sub_graph->nodes = nodes; sub_graph->SetMetaDef(meta_def); @@ -274,7 +320,7 @@ static std::map> GetNgSupportedOps(const int std::map> ng_supported_ops; ng_supported_ops.emplace(kOnnxDomain, ngraph::onnx_import::get_supported_operators(onnx_opset, kOnnxDomain)); - const std::set ng_disabled_ops = {}; //Place-holder for ops not supported. + const std::set ng_disabled_ops = {"LSTM", "Gather"}; //Place-holder for ops not supported. 
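A few hunks above, NGRAPHCustomOp::Initialize gained a bounded cache of compiled executables: keyCache, a std::list kept most-recently-used first, paired with ng_exe_map_, with the capacity read from the ONNXRUNTIME_NGRAPH_LRU_CACHE_SIZE environment variable (default 500). The sketch below shows that list-plus-map LRU pattern in isolation; LruCache is a hypothetical helper, not part of the patch, and it stores list iterators in the map so the "touch" step is O(1) rather than the O(n) std::list::remove used above.

// Hedged sketch of the list + unordered_map LRU pattern used by the nGraph EP cache.
#include <list>
#include <string>
#include <unordered_map>
#include <utility>

template <typename V>
class LruCache {
 public:
  explicit LruCache(size_t capacity) : capacity_(capacity) {}

  // Returns a pointer to the cached value, or nullptr on a miss.
  V* Get(const std::string& key) {
    auto it = map_.find(key);
    if (it == map_.end()) return nullptr;
    order_.splice(order_.begin(), order_, it->second.second);  // move key to the front
    return &it->second.first;
  }

  void Put(const std::string& key, V value) {
    if (V* existing = Get(key)) {    // refresh an existing entry
      *existing = std::move(value);
      return;
    }
    if (map_.size() == capacity_) {  // evict the least recently used key
      map_.erase(order_.back());
      order_.pop_back();
    }
    order_.push_front(key);
    map_.emplace(key, std::make_pair(std::move(value), order_.begin()));
  }

 private:
  size_t capacity_;
  std::list<std::string> order_;  // front = most recently used
  std::unordered_map<std::string, std::pair<V, std::list<std::string>::iterator>> map_;
};

In the patch the key is the serialized input-shape string and the value is the compiled ngraph executable.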
for (const auto& disabled_op : ng_disabled_ops) { ng_supported_ops.at(kOnnxDomain).erase(disabled_op); @@ -283,7 +329,8 @@ static std::map> GetNgSupportedOps(const int return ng_supported_ops; } -static std::vector GetUnsupportedNodeIndices(const GraphViewer& graph_viewer, /*out*/ std::unordered_set& ng_required_initializers) { +static std::vector +GetUnsupportedNodeIndices(const GraphViewer& graph_viewer, /*out*/ std::unordered_set& ng_required_initializers) { const auto ng_supported_ops = GetNgSupportedOps(GetOnnxOpSet(graph_viewer)); std::vector unsupported_nodes_idx; @@ -303,10 +350,12 @@ static std::vector GetUnsupportedNodeIndices(const GraphViewer& graph return unsupported_nodes_idx; } -/* Returns a vector clusters(or node_idx). For each unsupported node, the graph is split into 3 parts. - supported_cluster + (UNsupported_node + rest_of_the_graph). This functions returns vector of all supported_clusters by nGraph -*/ -static std::vector> GetPartitionedClusters(const std::vector& topological_order, const std::vector& unsupported_nodes) { +/** + * Returns a vector clusters(or node_idx). For each unsupported node, the graph is split into 3 parts. + * supported_cluster + (UNsupported_node + rest_of_the_graph). This functions returns vector of all supported_clusters by nGraph + */ +static std::vector> +GetPartitionedClusters(const std::vector& topological_order, const std::vector& unsupported_nodes) { std::vector> ng_clusters; auto prev = topological_order.begin(); @@ -457,7 +506,7 @@ NGRAPHExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_vie [&outputs](const NodeArg* node_arg) { outputs.push_back(node_arg->Name()); }); // Create and add this graph to result. - AppendClusterToSubGraph(graph_viewer.GetNodesInTopologicalOrder(), graph_viewer, inputs, outputs, ng_required_initializers, result); + AppendClusterToSubGraph(graph_viewer.GetNodesInTopologicalOrder(), graph_viewer, inputs, outputs, result); } else { // unsupported_nodes_idx.empty() const auto ng_clusters = GetPartitionedClusters(graph_viewer.GetNodesInTopologicalOrder(), unsupported_nodes); @@ -467,7 +516,7 @@ NGRAPHExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_vie GetInputsOutputsOfCluster(graph_viewer, this_cluster, ng_required_initializers, cluster_inputs, cluster_outputs); if (!cluster_inputs.empty()) { - AppendClusterToSubGraph(this_cluster, graph_viewer, cluster_inputs, cluster_outputs, ng_required_initializers, result); + AppendClusterToSubGraph(this_cluster, graph_viewer, cluster_inputs, cluster_outputs, result); } } } @@ -476,35 +525,21 @@ NGRAPHExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_vie } static ONNX_NAMESPACE::ModelProto GetModelProtoFromFusedNode(const onnxruntime::Node* fused_node) { - const auto& attributes = fused_node->GetAttributes(); - const auto& initializers = attributes.at("initializers").tensors(); - - ONNX_NAMESPACE::ModelProto model_proto; - auto graph_proto = model_proto.mutable_graph(); - const auto& fused_graph = fused_node->GetFunctionBody()->Body(); + const auto* node_function = fused_node->GetFunctionBody(); - for (const auto& node : fused_graph.Nodes()) { - node.ToProto(*(graph_proto->add_node())); - } + ORT_ENFORCE(node_function != nullptr, "Could not extract function body for node: ", fused_node->Name()); - for (const auto& input : fused_node->InputDefs()) { - auto valueInfoProto = graph_proto->add_input(); - *valueInfoProto = input->ToProto(); - } + const Graph& node_subgraph = node_function->Body(); + 
onnxruntime::Model model{node_subgraph.Name(), true}; - for (const auto& output : fused_node->OutputDefs()) { - auto valueInfoProto = graph_proto->add_output(); - *valueInfoProto = output->ToProto(); - } + ONNX_NAMESPACE::ModelProto model_proto = model.ToProto(); + model_proto.set_ir_version(ONNX_NAMESPACE::Version::IR_VERSION); - for (const auto& initializer : initializers) { - graph_proto->add_initializer()->CopyFrom(initializer); - } + *(model_proto.mutable_graph()) = node_subgraph.ToGraphProto(); auto opset = model_proto.add_opset_import(); opset->set_domain(kOnnxDomain); - opset->set_version(fused_graph.DomainToVersionMap().at(kOnnxDomain)); - model_proto.set_ir_version(ONNX_NAMESPACE::Version::IR_VERSION); + opset->set_version(node_subgraph.DomainToVersionMap().at(kOnnxDomain)); return model_proto; } @@ -512,14 +547,14 @@ static ONNX_NAMESPACE::ModelProto GetModelProtoFromFusedNode(const onnxruntime:: Status NGRAPHExecutionProvider::Compile(const std::vector& fused_nodes, std::vector& node_compute_funcs) { for (const auto& fused_node : fused_nodes) { - auto model_proto = GetModelProtoFromFusedNode(fused_node); - NodeComputeInfo compute_info; // Local copy of backend since, class members cannot be captured. auto ngraph_backend = ng_backend_; - compute_info.create_state_func = [model_proto, ngraph_backend](ComputeContext* context, FunctionState* state) { - auto* p = new onnxruntime::ngraph_ep::NGRAPHCustomOp(context, model_proto, ngraph_backend); + compute_info.create_state_func = [model_proto = GetModelProtoFromFusedNode(fused_node), ngraph_backend] + (ComputeContext* context, FunctionState* state) + { + auto* p = new ngraph_ep::NGRAPHCustomOp(context, model_proto, ngraph_backend); *state = p; return 0; }; diff --git a/onnxruntime/core/providers/ngraph/ngraph_execution_provider.h b/onnxruntime/core/providers/ngraph/ngraph_execution_provider.h index f4081a43a555b..daade7022d44d 100644 --- a/onnxruntime/core/providers/ngraph/ngraph_execution_provider.h +++ b/onnxruntime/core/providers/ngraph/ngraph_execution_provider.h @@ -4,12 +4,13 @@ #pragma once #include "core/framework/execution_provider.h" +#include namespace ngraph { -namespace runtime { -class Backend; + namespace runtime { + class Backend; + } } -} // namespace ngraph namespace onnxruntime { @@ -35,4 +36,4 @@ class NGRAPHExecutionProvider : public IExecutionProvider { std::shared_ptr ng_backend_; }; -} // namespace onnxruntime +} diff --git a/onnxruntime/core/providers/nuphar/common/analysis/analysis.h b/onnxruntime/core/providers/nuphar/common/analysis/analysis.h new file mode 100644 index 0000000000000..9c4e771761814 --- /dev/null +++ b/onnxruntime/core/providers/nuphar/common/analysis/analysis.h @@ -0,0 +1,45 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#pragma once +#include "core/codegen/common/common.h" +#include "core/common/common.h" +#include "core/graph/graph_viewer.h" +#include "core/providers/nuphar/common/nuphar_subgraph.h" + +namespace onnxruntime { +namespace nuphar { + +// abstract class for Analysis +template +class AnalysisBase { + public: + AnalysisBase() {} + + AnalysisBase(const std::string& name) + : name_(name) {} + + virtual ~AnalysisBase() = default; + + virtual void Evaluate(INPUT_TYPE) = 0; + + const std::string& Name() const { + return name_; + } + + protected: + const std::string name_{"Unknown"}; + + private: + ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(AnalysisBase); +}; + +using OrtAnalysis = AnalysisBase; +using NupharAnalysis = AnalysisBase; + +// Add Promote for OrtAnalysis and NupharAnalysis +DYNAMIC_PROMOTE(OrtAnalysis) +DYNAMIC_PROMOTE(NupharAnalysis) + +} // namespace nuphar +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/nuphar/common/analysis/graph_stats.h b/onnxruntime/core/providers/nuphar/common/analysis/graph_stats.h new file mode 100644 index 0000000000000..f86a28695b9b9 --- /dev/null +++ b/onnxruntime/core/providers/nuphar/common/analysis/graph_stats.h @@ -0,0 +1,77 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/codegen/common/common.h" +#include "core/common/common.h" +#include "core/graph/graph_viewer.h" +#include "core/providers/nuphar/common/analysis/analysis.h" + +#include "core/providers/nuphar/common/nuphar_subgraph.h" +// Base class of GraphStatsBase +// GraphStatsBase holds analysis results from a graph +// GraphStatsBase can hold multiple analyses + +namespace onnxruntime { +namespace nuphar { + +template +class GraphStatsBase { + public: + GraphStatsBase(const std::string& name) + : name_(name) {} + + GraphStatsBase() {} + + virtual ~GraphStatsBase() = default; + + // Evaluate all passes + virtual void Evaluate(INPUT_TYPE graph) { + for (auto& pass : passes_) { + pass->Evaluate(graph); + } + } + + // Set passes externally + void SetAllPasses(const std::vector>>& passes) { + passes_.clear(); + for (auto& pass : passes) { + passes_.push_back(pass); + } + } + + // Set existed evaluated passes externally + void SetAllExistedEvaluatedPasses( + const std::vector>>& passes) { + existed_eval_passes_.clear(); + for (auto& pass : passes) { + existed_eval_passes_.push_back(pass); + } + } + + const std::string& Name() const { + return name_; + } + + protected: + const std::string name_{"Unknown"}; + + std::vector>> passes_; + + private: + // existed eval passes not requiring evaluation + std::vector>> existed_eval_passes_; + + ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(GraphStatsBase); +}; + +using OrtGraphStats = GraphStatsBase; +using NupharSubgraphUnitStats = GraphStatsBase; + +// Add Promote for OrtGraphStats and NupharSubgraphUnitStats +DYNAMIC_PROMOTE(OrtGraphStats) +DYNAMIC_PROMOTE(NupharSubgraphUnitStats) + +} // namespace nuphar +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/nuphar/common/analysis/output_alias_analysis.cc b/onnxruntime/core/providers/nuphar/common/analysis/output_alias_analysis.cc new file mode 100644 index 0000000000000..1283f9cd409a8 --- /dev/null +++ b/onnxruntime/core/providers/nuphar/common/analysis/output_alias_analysis.cc @@ -0,0 +1,109 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#include "core/providers/nuphar/common/analysis/output_alias_analysis.h" + +#include "core/codegen/common/common.h" + +namespace onnxruntime { +namespace nuphar { + +void OutputAliasAnalysis::Traverse(const std::vector& nodes, + const std::set& graph_inputs, + const std::set& graph_outputs) { + for (auto& node : nodes) { + if (node->NodeType() == Node::Type::Fused) { + // unboxing of fused node + const auto& func_body = GraphViewer(node->GetFunctionBody()->Body()); + Traverse(ConvertGraphNodesToNodePtrs(func_body.Nodes()), graph_inputs, graph_outputs); + } else { + // TODO: change identity to other alias + bool is_identity = (node->OpType() == "Identity"); + node->ForEachWithIndex( + node->OutputDefs(), + [&](const NodeArg& def, size_t) { + if (graph_outputs.count(def.Name()) > 0) { + NodeKey key = GetKey(node); + output_nodes_.insert(key); + if (is_identity) { + auto input_def = node->InputDefs()[0]; + // regard as aliased if input_def is not graph input + // otherwise, we still generate Identity ops in TVM + // TODO: remove once we have a better solution for alias optimization + if (graph_inputs.count(input_def->Name()) == 0) { + alias_use_defs_.insert(std::make_pair(key, input_def)); + NodeKey input_key = GetKey(input_def); + output_nodes_.insert(input_key); + } + } + } + return Status::OK(); + }); + } + } +} + +// TODO: please reimplement output alias using the right algorithm. +// Currently we only copy it from old graph_stats, which is still wrong one +void OutputAliasAnalysis::Evaluate(const onnxruntime::nuphar::NupharSubgraphUnit& graph) { + if (graph.IsSingleNode()) { + const Node* node = graph.nodes.front(); + auto subgraph = GetSubgraph(*node); + + if (nullptr != subgraph) { + std::set graph_inputs; + std::set graph_outputs; + const auto& graph_viewer = GraphViewer(*subgraph); + for (const auto* def : graph_viewer.GetInputs()) { + if (nullptr != def) { + graph_inputs.insert(def->Name()); + } + } + for (const auto* def : graph_viewer.GetOutputs()) { + if (nullptr != def) { + graph_outputs.insert(def->Name()); + } + } + Traverse(ConvertGraphNodesToNodePtrs(graph_viewer.Nodes()), graph_inputs, graph_outputs); + } else { + NodeKey key = GetKey(node); + output_nodes_.insert(key); + } + } else { + // outputs names + std::set graph_inputs; + std::set graph_outputs; + for (const auto* def : graph.inputs) { + if (nullptr != def) { + graph_inputs.insert(def->Name()); + } + } + for (const auto* def : graph.outputs) { + if (nullptr != def) { + graph_outputs.insert(def->Name()); + } + } + Traverse(graph.nodes, graph_inputs, graph_outputs); + } +} + +bool OutputAliasAnalysis::IsOutputNode(const onnxruntime::Node* node) const { + return output_nodes_.count(GetKey(node)) != 0; +} + +bool OutputAliasAnalysis::IsOutputAlias(const onnxruntime::Node* node) const { + auto key = GetKey(node); + return alias_use_defs_.count(key) != 0; +} + +const onnxruntime::NodeArg* +OutputAliasAnalysis::SourceDefOfOutputAlias(const onnxruntime::NodeArg* node) const { + auto iter = alias_use_defs_.find(GetKey(node)); + if (iter != alias_use_defs_.end()) { + return iter->second; + } + return nullptr; +} + +} // namespace nuphar +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/nuphar/common/analysis/output_alias_analysis.h b/onnxruntime/core/providers/nuphar/common/analysis/output_alias_analysis.h new file mode 100644 index 0000000000000..57a86205c5041 --- /dev/null +++ b/onnxruntime/core/providers/nuphar/common/analysis/output_alias_analysis.h @@ -0,0 +1,43 @@ +// Copyright (c) Microsoft 
Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/codegen/common/common.h" +#include "core/graph/graph.h" +#include "core/providers/nuphar/common/analysis/analysis.h" + +namespace onnxruntime { +namespace nuphar { + +class OutputAliasAnalysis : public NupharAnalysis { + public: + OutputAliasAnalysis() + : NupharAnalysis("OutputAliasAnalysis") {} + + ~OutputAliasAnalysis() = default; + + void Evaluate(const onnxruntime::nuphar::NupharSubgraphUnit& graph) override; + + bool IsOutputNode(const onnxruntime::Node* node) const; + + bool IsOutputAlias(const onnxruntime::Node* node) const; + + const onnxruntime::NodeArg* SourceDefOfOutputAlias(const onnxruntime::NodeArg* node) const; + + private: + // a set for output nodes + std::set output_nodes_; + // a map from an output alias to its input + std::map alias_use_defs_; + + void Traverse(const std::vector& nodes, + const std::set& graph_inputs, + const std::set& graph_outputs); + + private: + ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(OutputAliasAnalysis); +}; + +} // namespace nuphar +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/nuphar/common/analysis/shape_expr.h b/onnxruntime/core/providers/nuphar/common/analysis/shape_expr.h new file mode 100644 index 0000000000000..76495c77df662 --- /dev/null +++ b/onnxruntime/core/providers/nuphar/common/analysis/shape_expr.h @@ -0,0 +1,243 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once +#include "core/common/common.h" + +// TODO retire this file + +namespace onnxruntime { + +// A mini IR layer for shape inference +// Currently just use tvm::Expr but can be replaced by others later +// Following features are needed: +// 1. represent symbolic int +// 2. represent +-*/ +// 3. check if two DimExpr is the same +// 4. 
simplify if needed +// For now only symbolic int is supported +class SimpleDimExpr { + public: + SimpleDimExpr() : has_value_(false) {} + SimpleDimExpr(int64_t i) : value_(i), has_value_(true) {} + SimpleDimExpr(const std::string& sym) : symbol_(sym), has_value_(false) {} + bool IsConst() const { return has_value_; } + bool IsOne() const { return has_value_ && value_ == 1; } + bool operator==(const SimpleDimExpr& expr) const { + if (has_value_ != expr.has_value_) + return false; + + if (has_value_) + return value_ == expr.value_; + else + return symbol_ == expr.symbol_; + } + + bool operator!=(const SimpleDimExpr& expr) const { + return !(*this == expr); + } + + SimpleDimExpr operator+(const SimpleDimExpr& other) const { + ORT_ENFORCE(has_value_ && other.has_value_); + return SimpleDimExpr(value_ + other.value_); + } + + SimpleDimExpr operator-(const SimpleDimExpr& other) const { + ORT_ENFORCE(has_value_ && other.has_value_); + return SimpleDimExpr(value_ - other.value_); + } + + SimpleDimExpr operator*(const SimpleDimExpr& other) const { + if (has_value_ && other.has_value_) + return SimpleDimExpr(value_ * other.value_); + else if (IsOne()) + return other; + else if (other.IsOne()) + return *this; + else + ORT_ENFORCE(false, "unsupported symbolic dim computation"); + } + + SimpleDimExpr operator/(const SimpleDimExpr& other) const { + if (has_value_ && other.has_value_) + return SimpleDimExpr(value_ / other.value_); + else if (other.IsOne()) + return *this; + else + ORT_ENFORCE(false, "unsupported symbolic dim computation"); + } + + SimpleDimExpr operator%(const SimpleDimExpr& other) const { + ORT_ENFORCE(has_value_ && other.has_value_); + return SimpleDimExpr(value_ % other.value_); + } + + int64_t Value() const { + ORT_ENFORCE(IsConst()); + return value_; + } + + const std::string& Symbol() const { + ORT_ENFORCE(!IsConst()); + return symbol_; + } + + std::string ToString() const { + if (has_value_) + return std::to_string(value_); + else + return symbol_; + } + + private: + std::string symbol_; + int64_t value_; + bool has_value_; +}; + +template +class ShapeExprT { + public: + ShapeExprT() = default; + ShapeExprT(const ShapeExprT& expr) = default; + ShapeExprT(ShapeExprT&& expr) = default; + ShapeExprT(size_t size) { dims_.resize(size); } + ShapeExprT(const std::vector& dims) : dims_(dims) {} + ShapeExprT(const std::vector& dims) { + for (auto dim : dims) + dims_.push_back(DimT(dim)); + } + + size_t Rank() const { + return dims_.size(); + } + + int64_t TotalKnown() const { + if (dims_.size() == 0) + return 1; + int64_t total = 1; + for (size_t i = 0; i < dims_.size(); ++i) { + if (dims_[i].IsConst()) + total = total * dims_[i].Value(); + } + return total; + } + + size_t KnownFromDimension() const { + size_t min_index = dims_.size(); + for (int i = static_cast(dims_.size() - 1); i >= 0; i--) { + if (!dims_[i].IsConst()) + break; + min_index = static_cast(i); + } + return min_index; + } + + std::vector TailedKnown() const { + std::vector result; + + for (size_t i = KnownFromDimension(); i < Rank(); ++i) { + result.push_back(dims_[i].Value()); + } + return result; + } + + int64_t TotalTailedKnown() const { + int64_t result = 1; + for (size_t i = KnownFromDimension(); i < Rank(); ++i) { + result *= dims_[i].Value(); + } + return result; + } + + /** + Return the total number of elements up to the specified dimension. + @param dim Return size up to this dimension. Value must be >= 0 and < this.Size(). 
+  */
+  DimT SizeToDimension(size_t dim) const {
+    DimT total(1);
+    for (size_t i = 0; i < std::min(dim, dims_.size()); ++i)
+      total = total * dims_[i];
+    return total;
+  }
+
+  /**
+  Return the total number of elements from the specified dimension to the end of the tensor shape.
+  @param dim Return size from this dimension to the end. 0 <= dim < this.Size().
+  */
+  DimT SizeFromDimension(size_t dim) const {
+    DimT total(1);
+    for (size_t i = dim; i < dims_.size(); ++i)
+      total = total * dims_[i];
+    return total;
+  }
+
+  bool IsConst() const {
+    return std::all_of(dims_.begin(), dims_.end(), [](const DimT& dim) { return dim.IsConst(); });
+  }
+
+  bool operator==(const ShapeExprT& shape) const {
+    if (Rank() != shape.Rank())
+      return false;
+
+    for (size_t dim = 0; dim < Rank(); ++dim) {
+      if (dims_[dim] != (shape.dims_[dim]))
+        return false;
+    }
+    return true;
+  }
+
+  const ShapeExprT& operator=(const ShapeExprT& shape) {
+    dims_ = shape.dims_;
+    return *this;
+  }
+
+  const DimT& at(size_t dim) const {
+    ORT_ENFORCE(dim < Rank());
+    return dims_[dim];
+  }
+
+  DimT& at(size_t dim) {
+    ORT_ENFORCE(dim < Rank());
+    return dims_[dim];
+  }
+
+  const DimT& operator[](size_t dim) const {
+    ORT_ENFORCE(dim < Rank());
+    return dims_[dim];
+  }
+
+  DimT& operator[](size_t dim) {
+    ORT_ENFORCE(dim < Rank());
+    return dims_[dim];
+  }
+
+  const std::vector Value() const {
+    ORT_ENFORCE(IsConst());
+    std::vector result;
+    for (size_t i = 0; i < Rank(); ++i) {
+      result.push_back(dims_[i].Value());
+    }
+    return result;
+  }
+
+  std::string ToString() const {
+    std::ostringstream oss;
+    oss << "(";
+    for (size_t i = 0; i < Rank(); ++i) {
+      if (i > 0)
+        oss << ", ";
+      oss << dims_[i].ToString();
+    }
+    oss << ")";
+    return oss.str();
+  }
+
+ private:
+  std::vector dims_;
+};
+
+typedef SimpleDimExpr DimExpr;
+typedef ShapeExprT ShapeExpr;
+
+} // namespace onnxruntime
diff --git a/onnxruntime/core/providers/nuphar/common/analysis/subgraph_codegen_stats.cc b/onnxruntime/core/providers/nuphar/common/analysis/subgraph_codegen_stats.cc
new file mode 100644
index 0000000000000..836c3b98291c8
--- /dev/null
+++ b/onnxruntime/core/providers/nuphar/common/analysis/subgraph_codegen_stats.cc
@@ -0,0 +1,63 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
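+//
+// NOTE: the analysis results below are fetched by fixed offsets into passes_,
+// so the push_back order in the constructor must stay in sync with the
+// *AnalysisOffset constants defined in this file.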
+ +#include "core/providers/nuphar/common/analysis/subgraph_codegen_stats.h" + +#include "core/providers/nuphar/common/analysis/output_alias_analysis.h" +#include "core/providers/nuphar/common/analysis/use_count_analysis.h" + +namespace onnxruntime { +namespace nuphar { + +// CodeGenUnitStats has two analysis passes +// The first pass, offset as 0, is UseCountAnalysis +// The second pass, offset as 1, is OutputAliasAnalysis +constexpr int UseCountAnalysisOffset = 0; +constexpr int OutputAliasAnalysisOffset = 1; + +// True reuse count for cheap Op +constexpr int CheapNodeTrueReuseCount = 2; + +// Constructor +CodeGenUnitStats::CodeGenUnitStats( + const std::shared_ptr& shape_infernece) + : NupharSubgraphUnitStats("CodeGenUnitStats") { + auto use_count_pass = std::make_shared(shape_infernece); + passes_.push_back(use_count_pass); + + auto output_alias_pass = std::make_shared(); + passes_.push_back(output_alias_pass); +} + +int CodeGenUnitStats::NodeUseCount(const onnxruntime::Node* node) const { + ORT_ENFORCE(passes_.size() > UseCountAnalysisOffset); + return Promote(passes_[UseCountAnalysisOffset])->NodeUseCount(node); +} + +bool CodeGenUnitStats::IsCheapNodeReuse(const onnxruntime::Node* node) const { + ORT_ENFORCE(passes_.size() > UseCountAnalysisOffset); + // Define cheap nodes include Add / Sub / Mul + if (node->OpType() == "Add" || node->OpType() == "Sub" || node->OpType() == "Mul") + return Promote(passes_[UseCountAnalysisOffset])->NodeUseCount(node) > CheapNodeTrueReuseCount; + + // Otherwise return true and use count is determined by NodeUseCount + return true; +} + +bool CodeGenUnitStats::IsOutputNode(const onnxruntime::Node* node) const { + ORT_ENFORCE(passes_.size() > OutputAliasAnalysisOffset); + return Promote(passes_[OutputAliasAnalysisOffset])->IsOutputNode(node); +} + +bool CodeGenUnitStats::IsOutputAlias(const onnxruntime::Node* node) const { + ORT_ENFORCE(passes_.size() > OutputAliasAnalysisOffset); + return Promote(passes_[OutputAliasAnalysisOffset])->IsOutputAlias(node); +} + +const onnxruntime::NodeArg* CodeGenUnitStats::SourceDefOfOutputAlias(const onnxruntime::NodeArg* node) const { + ORT_ENFORCE(passes_.size() > OutputAliasAnalysisOffset); + return Promote(passes_[OutputAliasAnalysisOffset])->SourceDefOfOutputAlias(node); +} + +} // namespace nuphar +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/nuphar/common/analysis/subgraph_codegen_stats.h b/onnxruntime/core/providers/nuphar/common/analysis/subgraph_codegen_stats.h new file mode 100644 index 0000000000000..69b135f73e34d --- /dev/null +++ b/onnxruntime/core/providers/nuphar/common/analysis/subgraph_codegen_stats.h @@ -0,0 +1,34 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#pragma once +#include "core/codegen/common/common.h" +#include "core/providers/nuphar/common/analysis/graph_stats.h" +#include "core/providers/nuphar/common/analysis/use_count_analysis.h" +#include "core/providers/nuphar/common/nuphar_subgraph.h" + +namespace onnxruntime { +namespace nuphar { + +class CodeGenUnitStats : public NupharSubgraphUnitStats { + public: + CodeGenUnitStats(const std::shared_ptr& shape_infernece); + + ~CodeGenUnitStats() = default; + + int NodeUseCount(const onnxruntime::Node* node) const; + + bool IsCheapNodeReuse(const onnxruntime::Node* node) const; + + bool IsOutputNode(const onnxruntime::Node* node) const; + + bool IsOutputAlias(const onnxruntime::Node* node) const; + + const onnxruntime::NodeArg* SourceDefOfOutputAlias(const onnxruntime::NodeArg* node) const; + + private: + ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(CodeGenUnitStats); +}; + +} // namespace nuphar +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/nuphar/common/analysis/subgraph_partition_stats.cc b/onnxruntime/core/providers/nuphar/common/analysis/subgraph_partition_stats.cc new file mode 100644 index 0000000000000..b35b456bd6bbd --- /dev/null +++ b/onnxruntime/core/providers/nuphar/common/analysis/subgraph_partition_stats.cc @@ -0,0 +1,28 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/nuphar/common/analysis/subgraph_partition_stats.h" + +#include "core/providers/nuphar/common/analysis/use_count_analysis.h" + +namespace onnxruntime { +namespace nuphar { + +// TODO: Add memory analysis +// SubgraphPartitionStats has one analysis pass +// The first pass, offset as 0, is UseCountAnalysis +constexpr int UseCountAnalysisOffset = 0; + +void SubgraphPartitionStats::SetShapeInference( + const std::shared_ptr& shape_infernece) { + passes_.clear(); + passes_.emplace_back(std::make_shared(shape_infernece)); +} + +int SubgraphPartitionStats::NodeUseCount(const onnxruntime::Node* node) const { + ORT_ENFORCE(passes_.size() > UseCountAnalysisOffset); + return Promote(passes_[UseCountAnalysisOffset])->NodeUseCount(node); +} + +} // namespace nuphar +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/nuphar/common/analysis/subgraph_partition_stats.h b/onnxruntime/core/providers/nuphar/common/analysis/subgraph_partition_stats.h new file mode 100644 index 0000000000000..afbcf0d2a3886 --- /dev/null +++ b/onnxruntime/core/providers/nuphar/common/analysis/subgraph_partition_stats.h @@ -0,0 +1,31 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
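+
+// SubgraphPartitionStats exposes node use counts at graph-partitioning time.
+// Unlike CodeGenUnitStats, it currently runs only the use-count analysis pass
+// (memory analysis is still a TODO in subgraph_partition_stats.cc).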
+ +#pragma once + +#include "core/codegen/common/common.h" +#include "core/providers/nuphar/common/analysis/graph_stats.h" +#include "core/providers/nuphar/compiler/traverse_shape_infer.h" + +namespace onnxruntime { +namespace nuphar { + +// TODO: rename class name to more target-specific in the tvm refactoring +// Maybe GraphPartitionStatsX86 +class SubgraphPartitionStats : public OrtGraphStats { + public: + SubgraphPartitionStats() + : OrtGraphStats("SubgraphPartitionStats") {} + + ~SubgraphPartitionStats() = default; + + void SetShapeInference(const std::shared_ptr& shape_infernece); + + int NodeUseCount(const onnxruntime::Node* node) const; + + private: + ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(SubgraphPartitionStats); +}; + +} // namespace nuphar +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/nuphar/common/analysis/use_count_analysis.cc b/onnxruntime/core/providers/nuphar/common/analysis/use_count_analysis.cc new file mode 100644 index 0000000000000..c754db4377fba --- /dev/null +++ b/onnxruntime/core/providers/nuphar/common/analysis/use_count_analysis.cc @@ -0,0 +1,264 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/nuphar/common/analysis/use_count_analysis.h" + +#include "core/codegen/common/common.h" +#include "core/graph/function.h" + +namespace onnxruntime { +namespace nuphar { + +constexpr int PRESET_USE_COUNT_FOR_UNKNOWN = 10; +constexpr int PRESET_USE_COUNT_FOR_SOFTMAX = 3; + +static void CountGemmOp(const onnxruntime::Node& node, + const std::vector& graph_inputs, + std::function shape_func, + std::unordered_map& node_use_counts); + +static void CountMatMulOp(const onnxruntime::Node& node, + const std::vector& graph_inputs, + std::function shape_func, + std::unordered_map& node_use_counts); + +static void CountRecurrentOp(const onnxruntime::Node& node, + const std::vector& graph_inputs, + std::function shape_func, + std::unordered_map& node_use_counts); + +static void CountMatrixArgs(const onnxruntime::NodeArg* A, + const onnxruntime::NodeArg* B, + const onnxruntime::Node& node, + const std::vector& graph_inputs, + std::function shape_func, + std::unordered_map& node_use_counts); + +static void CountNodeArg(const onnxruntime::NodeArg* input_def, + const onnxruntime::Node& node, + const std::vector& graph_inputs, + std::unordered_map& node_use_counts, + int use_cnt); + +static bool IsMatMulOp(const std::string& op) { + return op == "MatMul" || op == "MatMulInteger" || op == "MatMulInteger16"; +} + +void CountGemmOp(const onnxruntime::Node& node, + const std::vector& graph_inputs, + std::function shape_func, + std::unordered_map& node_use_counts) { + ORT_ENFORCE(node.OpType() == "Gemm"); + + auto inputs = node.InputDefs(); + CountMatrixArgs(inputs[0], inputs[1], node, graph_inputs, shape_func, node_use_counts); + // C's use cnt is fixed. 
+ CountNodeArg(inputs[2], node, graph_inputs, node_use_counts, 1); +} + +void CountMatMulOp(const onnxruntime::Node& node, + const std::vector& graph_inputs, + std::function shape_func, + std::unordered_map& node_use_counts) { + ORT_ENFORCE(IsMatMulOp(node.OpType())); + auto inputs = node.InputDefs(); + CountMatrixArgs(inputs[0], inputs[1], node, graph_inputs, shape_func, node_use_counts); +} + +void CountRecurrentOp(const onnxruntime::Node& node, + const std::vector& graph_inputs, + std::function, + std::unordered_map& node_use_counts) { + int use_count = PRESET_USE_COUNT_FOR_UNKNOWN; + + node.ForEachWithIndex( + node.InputDefs(), + [&node, &graph_inputs, &node_use_counts, &use_count](const NodeArg& def, size_t) { + CountNodeArg(&def, node, graph_inputs, node_use_counts, use_count); + return Status::OK(); + }); +} + +static bool IsSoftmaxOp(const std::string& op) { + return op == "Softmax" || op == "LogSoftmax"; +} + +void CountSoftmaxOp(const onnxruntime::Node& node, + const std::vector& graph_inputs, + std::function, + std::unordered_map& node_use_counts) { + // Use preset use count for Softmax/LogSoftmax input + int use_count = PRESET_USE_COUNT_FOR_SOFTMAX; + + node.ForEachWithIndex( + node.InputDefs(), + [&node, &graph_inputs, &node_use_counts, &use_count](const NodeArg& def, size_t) { + CountNodeArg(&def, node, graph_inputs, node_use_counts, use_count); + return Status::OK(); + }); +} + +void CountMatrixArgs(const onnxruntime::NodeArg* A, + const onnxruntime::NodeArg* B, + const onnxruntime::Node& node, + const std::vector& graph_inputs, + std::function shape_func, + std::unordered_map& node_use_counts) { + int use_cnt = PRESET_USE_COUNT_FOR_UNKNOWN; + const ShapeExpr* a_shape = shape_func(A); + if (nullptr != a_shape) { + // B's use cnt is based on the rows of A + // skip symbolic dimensions for Sequence and batch + auto a_cols = (a_shape->Rank() > 0 && a_shape->at(a_shape->Rank() - 1).IsConst()) ? a_shape->at(a_shape->Rank() - 1).Value() : 1; + use_cnt = a_shape->TotalTailedKnown() / a_cols; + } + CountNodeArg(B, node, graph_inputs, node_use_counts, use_cnt); + + // reset use_cnt + use_cnt = PRESET_USE_COUNT_FOR_UNKNOWN; + const ShapeExpr* b_shape = shape_func(B); + if (nullptr != b_shape) { + const DimExpr& dim = b_shape->Rank() > 1 ? b_shape->at(b_shape->Rank() - 1) : DimExpr(1); + // A's use cnt is based on the cols of B. 
If B is 1-D, use cnt is 1 + if (dim.IsConst()) + use_cnt = dim.Value(); + } + + CountNodeArg(A, node, graph_inputs, node_use_counts, use_cnt); +} + +void CountNodeArg(const onnxruntime::NodeArg* input_def, + const onnxruntime::Node& node, + const std::vector& graph_inputs, + std::unordered_map& node_use_counts, + int use_cnt) { + // Skip graph's input args nodes + if (std::find(graph_inputs.begin(), graph_inputs.end(), input_def) != graph_inputs.end()) + return; + + const Node* input_node = GetInputNode(node, input_def); + + if (nullptr != input_node) { + node_use_counts[GetKey(input_node)] += use_cnt; + } +} + +InternalUseCountAnalysis::InternalUseCountAnalysis(const std::shared_ptr& shape_inference) { + shape_func_ = [&shape_inference](const onnxruntime::NodeArg* X) { + return shape_inference->Lookup(X); + }; +} + +void InternalUseCountAnalysis::Traverse( + const std::vector& nodes, + const std::vector& graph_inputs, + const std::vector& graph_outputs) { + for (auto& node : nodes) { + auto op_type = node->OpType(); + if (op_type == "Gemm") { + CountGemmOp(*node, graph_inputs, shape_func_, node_use_counts_); + } else if (IsMatMulOp(op_type)) { + CountMatMulOp(*node, graph_inputs, shape_func_, node_use_counts_); + } else if (op_type == "Scan") { + auto subgraph = node->GetGraphAttribute("body"); + Evaluate(GraphViewer(*subgraph)); + int use_count = PRESET_USE_COUNT_FOR_UNKNOWN; + node->ForEachWithIndex( + node->InputDefs(), + [this, &node, &graph_inputs, &use_count](const NodeArg& def, size_t) { + CountNodeArg(&def, *node, graph_inputs, node_use_counts_, use_count); + return Status::OK(); + }); + } else if (IsRecurrentNode(*node)) { + CountRecurrentOp(*node, graph_inputs, shape_func_, node_use_counts_); + } else if (node->NodeType() == Node::Type::Fused) { + // note: when unboxing subgraph in fused node, use outermost graph input/output + const auto& func_body = GraphViewer(node->GetFunctionBody()->Body()); + Traverse(ConvertGraphNodesToNodePtrs(func_body.Nodes()), graph_inputs, graph_outputs); + } else if (IsSoftmaxOp(op_type)) { + CountSoftmaxOp(*node, graph_inputs, shape_func_, node_use_counts_); + } else { + int use_count = 1; + node->ForEachWithIndex( + node->InputDefs(), + [this, &node, &graph_inputs, &use_count](const NodeArg& def, size_t) { + CountNodeArg(&def, *node, graph_inputs, node_use_counts_, use_count); + return Status::OK(); + }); + } + + NodeKey key = GetKey(node); + // For any output_def of the node that is part of graph's outputs but not from graph.Nodes(), + // we need to increase the node's use cnt accordingly. Otherwise, we would lose those uses. 
+ node->ForEachWithIndex( + node->OutputDefs(), + [this, &graph_outputs, &key](const NodeArg& def, size_t) { + if (std::find(graph_outputs.begin(), graph_outputs.end(), &def) != graph_outputs.end()) { + node_use_counts_[key]++; + } + return Status::OK(); + }); + } +} + +void InternalUseCountAnalysis::Evaluate(const onnxruntime::GraphViewer& graph) { + const auto& graph_inputs = graph.GetInputs(); + const auto& graph_outputs = graph.GetOutputs(); + Traverse(ConvertGraphNodesToNodePtrs(graph.Nodes()), graph_inputs, graph_outputs); +} + +void InternalUseCountAnalysis::Evaluate(const onnxruntime::nuphar::NupharSubgraphUnit& graph) { + const auto& graph_inputs = graph.inputs; + const auto& graph_outputs = graph.outputs; + Traverse(graph.nodes, graph_inputs, graph_outputs); +} + +void InternalUseCountAnalysis::IncrementCount(const onnxruntime::NodeArg* def) { + node_use_counts_[GetKey(def)]++; +} + +int InternalUseCountAnalysis::NodeUseCount(const onnxruntime::Node* node) const { + auto node_iter = node_use_counts_.find(GetKey(node)); + if (node_iter != node_use_counts_.end()) { + return node_iter->second; + } else { + return 0; + } +} + +OrtUseCountAnalysis::OrtUseCountAnalysis(const std::shared_ptr& shape_inference) + : OrtAnalysis("OrtUseCountAnalysis") { + internal_analysis_ = std::make_unique(shape_inference); +} + +void OrtUseCountAnalysis::Evaluate(const onnxruntime::GraphViewer& graph) { + internal_analysis_->Evaluate(graph); +} + +void OrtUseCountAnalysis::IncrementCount(const onnxruntime::NodeArg* def) { + internal_analysis_->IncrementCount(def); +} + +int OrtUseCountAnalysis::NodeUseCount(const onnxruntime::Node* node) const { + return internal_analysis_->NodeUseCount(node); +} + +NupharUseCountAnalysis::NupharUseCountAnalysis(const std::shared_ptr& shape_inference) + : NupharAnalysis("NupharUseCountAnalysis") { + internal_analysis_ = std::make_unique(shape_inference); +} + +void NupharUseCountAnalysis::Evaluate(const onnxruntime::nuphar::NupharSubgraphUnit& graph) { + internal_analysis_->Evaluate(graph); +} + +void NupharUseCountAnalysis::IncrementCount(const onnxruntime::NodeArg* def) { + internal_analysis_->IncrementCount(def); +} + +int NupharUseCountAnalysis::NodeUseCount(const onnxruntime::Node* node) const { + return internal_analysis_->NodeUseCount(node); +} + +} // namespace nuphar +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/nuphar/common/analysis/use_count_analysis.h b/onnxruntime/core/providers/nuphar/common/analysis/use_count_analysis.h new file mode 100644 index 0000000000000..4d79af579b609 --- /dev/null +++ b/onnxruntime/core/providers/nuphar/common/analysis/use_count_analysis.h @@ -0,0 +1,83 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
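+
+// Use-count analysis estimates how often each node's output is consumed:
+// matrix operands of Gemm/MatMul are weighted by the known dimensions of the
+// other operand, recurrent ops and Scan bodies fall back to a preset count,
+// Softmax/LogSoftmax inputs use a smaller preset, and every other use counts
+// once. Producing a graph output adds one extra use to the producing node.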
+ +#pragma once +#include "core/codegen/common/common.h" +#include "core/providers/nuphar/common/analysis/analysis.h" +#include "core/providers/nuphar/common/analysis/shape_expr.h" +#include "core/providers/nuphar/compiler/traverse_shape_infer.h" +#include "core/graph/graph.h" + +#include +#include + +// TODO change namespace from codegen to nuphar + +namespace onnxruntime { +namespace nuphar { + +class InternalUseCountAnalysis { + public: + InternalUseCountAnalysis(const std::shared_ptr& shape_inference); + + ~InternalUseCountAnalysis() = default; + + void Evaluate(const onnxruntime::GraphViewer& graph); + + void Evaluate(const NupharSubgraphUnit& graph); + + void IncrementCount(const onnxruntime::NodeArg* arg); + + int NodeUseCount(const onnxruntime::Node* node) const; + + private: + void Traverse(const std::vector& nodes, + const std::vector& graph_inputs, + const std::vector& graph_outputs); + + std::unordered_map node_use_counts_; + std::function shape_func_; + + private: + ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(InternalUseCountAnalysis); +}; + +// TODO analysis move to namespace nuphar + +class OrtUseCountAnalysis : public OrtAnalysis { + public: + OrtUseCountAnalysis(const std::shared_ptr& shape_inference); + ~OrtUseCountAnalysis() = default; + + void Evaluate(const onnxruntime::GraphViewer& graph) override; + + void IncrementCount(const onnxruntime::NodeArg* arg); + + int NodeUseCount(const onnxruntime::Node* node) const; + + private: + std::unique_ptr internal_analysis_; + ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(OrtUseCountAnalysis); +}; + +class NupharUseCountAnalysis : public NupharAnalysis { + public: + NupharUseCountAnalysis(const std::shared_ptr& shape_inference); + + ~NupharUseCountAnalysis() = default; + + void Evaluate(const onnxruntime::nuphar::NupharSubgraphUnit& graph) override; + + void IncrementCount(const onnxruntime::NodeArg* arg); + + int NodeUseCount(const onnxruntime::Node* node) const; + + private: + std::unique_ptr internal_analysis_; + + private: + ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(NupharUseCountAnalysis); +}; + +} // namespace nuphar +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/nuphar/common/nuphar_settings.cc b/onnxruntime/core/providers/nuphar/common/nuphar_settings.cc new file mode 100644 index 0000000000000..1e3981004ba9b --- /dev/null +++ b/onnxruntime/core/providers/nuphar/common/nuphar_settings.cc @@ -0,0 +1,132 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
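+
+// Settings arrive as a comma-separated list of "key:value" pairs, e.g.
+//   "nuphar_fast_math:short_polynormial_math, nuphar_cache_path:/tmp/nuphar_cache"
+// (the cache path above is only an illustrative value). Keys are validated
+// against valid_keys. Unless a key was set explicitly in the settings string,
+// an upper-cased environment variable of the same name (e.g. NUPHAR_CACHE_PATH)
+// can override its value.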
+ +#include "core/providers/nuphar/common/nuphar_settings.h" + +#include "core/codegen/common/common.h" +#include "core/codegen/common/utils.h" +#include "core/common/logging/logging.h" +#include "core/providers/nuphar/nuphar_execution_provider.h" + +#include +#include +#include +#include + +namespace onnxruntime { +namespace nuphar { + +static const std::unordered_set valid_keys = { + codegen::CodeGenSettings::kDumpAllOptions, + codegen::CodeGenSettings::kCodeGenDumpModule, + codegen::CodeGenSettings::kCodeGenDumpLower, + codegen::CodeGenSettings::kCodeGenDumpSchedule, + kNupharFastMath, + kNupharFastActivation, + kNupharDumpFusedNodes, + kNupharDumpPartition, + kNupharIMatMulForceMkl, + kNupharMatmulExec, + kNupharCachePath, + kNupharCacheVersion, + kNupharCacheSoName, + kNupharCacheModelChecksum, + kNupharCacheForceNoJIT, + kNupharCodeGenTarget}; + +void SetDefaultOptions(std::map& options) { + // create two temporary strings to get rid of the odr-use issue introduced + // The issue would trigger missing definition errors for static constexpr members + // at link time. + std::string fast_math_opt(kNupharFastMath); + std::string select_fast_math(kNupharFastMath_ShortPolynormial); + options.insert(std::make_pair(fast_math_opt, select_fast_math)); + + std::string fast_act_opt(kNupharFastActivation); + std::string select_fast_act(kNupharActivations_DeepCpu); + options.insert(std::make_pair(fast_act_opt, select_fast_act)); + + // set jit cache so name + std::string cache_so_name_opt(kNupharCacheSoName); + std::string cache_so_name_default(kNupharCacheSoName_Default); + options.insert(std::make_pair(cache_so_name_opt, cache_so_name_default)); +} + +void CreateNupharCodeGenSettings(const NupharExecutionProviderInfo& info) { + std::map options; + SetDefaultOptions(options); + + std::unordered_set required_options; + if (!info.settings.empty()) { + const std::string& str = info.settings; + + // tokenize settings + std::regex reg("\\s*,\\s*"); + std::sregex_token_iterator iter(str.begin(), str.end(), reg, -1); + std::sregex_token_iterator iter_end; + std::vector pairs(iter, iter_end); + + ORT_ENFORCE(pairs.size() > 0); + for (const auto& pair : pairs) { + auto pos_colon = pair.find(':'); + ORT_ENFORCE(pos_colon != std::string::npos, "Invalid key value pair."); + std::string key = pair.substr(0, pos_colon); + std::string value = pair.substr(pos_colon + 1); + + // trim leading and trailing spaces from key/value + auto trim = [](const std::string& str) -> std::string { + const std::string WHITESPACE = " \n\r\t\f\v"; + size_t start = str.find_first_not_of(WHITESPACE); + if (start == std::string::npos) { + return ""; + } else { + size_t end = str.find_last_not_of(WHITESPACE); + ORT_ENFORCE(end != std::string::npos); + return str.substr(start, end + 1); + } + }; + key = trim(key); + value = trim(value); + + if (valid_keys.count(key) == 0) { + ORT_NOT_IMPLEMENTED("NupharCodeGenSettings: unknown option (", key, ")"); + } + required_options.insert(key); + options.insert(std::make_pair(key, value)); + } + } + +#ifndef GOLDEN_BUILD + // environment variables override existing settings + for (const auto& key : valid_keys) { + std::string env_key; + // env var is always upper case + std::transform(key.begin(), key.end(), std::back_inserter(env_key), (int (*)(int))std::toupper); + if (IsEnvVarDefined(env_key.c_str())) { + // value is case-sensitive + auto value = std::string(GetEnv(env_key.c_str()).get()); + + if (required_options.count(key) > 0 && options.at(key) != value) { + 
LOGS_DEFAULT(CODEGEN_SETTINGS_LOG_LEVEL)
+            << "NupharCodeGenSettings: option(" << key
+            << ") from environment variable is ignored because of existing required option value: "
+            << options.at(key);
+      } else {
+        options[key] = value;
+      }
+    }
+  }
+#endif
+
+  codegen::CodeGenSettings& settings = codegen::CodeGenSettings::Instance();
+  settings.Clear();  // remove previous settings and start from scratch
+
+  settings.InsertOptions(options);
+
+  if (settings.HasOption(codegen::CodeGenSettings::kDumpAllOptions)) {
+    settings.DumpOptions();
+  }
+}
+
+} // namespace nuphar
+} // namespace onnxruntime
diff --git a/onnxruntime/core/providers/nuphar/common/nuphar_settings.h b/onnxruntime/core/providers/nuphar/common/nuphar_settings.h
new file mode 100644
index 0000000000000..91d2f03a4b583
--- /dev/null
+++ b/onnxruntime/core/providers/nuphar/common/nuphar_settings.h
@@ -0,0 +1,48 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+#include "core/codegen/common/settings.h"
+
+namespace onnxruntime {
+
+// forward declaration
+struct NupharExecutionProviderInfo;
+
+namespace nuphar {
+constexpr static const char* kNupharDumpPartition = "nuphar_dump_partition";
+constexpr static const char* kNupharDumpFusedNodes = "nuphar_dump_fused_nodes";
+constexpr static const char* kNupharMatmulExec = "nuphar_matmul_exec";
+constexpr static const char* kNupharCachePath = "nuphar_cache_path";
+constexpr static const char* kNupharCacheVersion = "nuphar_cache_version";
+constexpr static const char* kNupharCacheSoName = "nuphar_cache_so_name";
+constexpr static const char* kNupharCacheModelChecksum = "nuphar_cache_model_checksum";
+constexpr static const char* kNupharCacheForceNoJIT = "nuphar_cache_force_no_jit";
+// force to use IMatMulExternMKL/IMatMul16ExternMKL
+constexpr static const char* kNupharIMatMulForceMkl = "nuphar_imatmul_force_mkl";
+
+constexpr static const char* kNupharMatMulExec_ExternCpu = "extern_cpu";
+
+constexpr static const char* kNupharFastMath = "nuphar_fast_math";  // fast math
+constexpr static const char* kNupharFastMath_Polynormial = "polynormial_math";  // generic polynomial fast math for exp and log
+constexpr static const char* kNupharFastMath_ShortPolynormial = "short_polynormial_math";  // generic shorter polynomial fast math for exp and log
+
+constexpr static const char* kNupharFastActivation = "nuphar_fast_activation";  // fast activation
+constexpr static const char* kNupharActivations_DeepCpu = "deep_cpu_activation";
+
+// Option to control nuphar code generation target (avx2 or avx512)
+constexpr static const char* kNupharCodeGenTarget = "nuphar_codegen_target";
+
+// cache version number (MAJOR.MINOR.PATCH) following https://semver.org/
+// 1. MAJOR version when you make incompatible changes, so that old cache files no longer work,
+// 2. MINOR version when you add functionality in a backwards-compatible manner, and
+// 3. PATCH version when you make backwards-compatible bug fixes.
+// NOTE this version needs to be updated when generated code may change +constexpr static const char* kNupharCacheVersion_Current = "1.0.0"; + +constexpr static const char* kNupharCacheSoName_Default = "jit.so"; + +void CreateNupharCodeGenSettings(const NupharExecutionProviderInfo& info); + +} // namespace nuphar +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/nuphar/common/nuphar_subgraph.h b/onnxruntime/core/providers/nuphar/common/nuphar_subgraph.h new file mode 100644 index 0000000000000..06e105150ad54 --- /dev/null +++ b/onnxruntime/core/providers/nuphar/common/nuphar_subgraph.h @@ -0,0 +1,106 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once +#include "core/framework/tensor.h" +#include "core/graph/graph.h" +#include "core/graph/graph_viewer.h" + +#include +#include +#include + +namespace onnxruntime { +namespace nuphar { + +using FindInitializerFunc = std::function; + +struct OrtSubgraphAllocationInfo { + std::unordered_map internal_allocator_offset; + std::unordered_map inputs; + std::unordered_map outputs; + int offset_count; + + OrtSubgraphAllocationInfo(const Node& node) : offset_count(0) { + int input_counter = 0; + int output_counter = 0; + + node.ForEachDef( + [&input_counter, &output_counter, this](const NodeArg& def, bool is_input) { + const std::string& def_name = def.Name(); + if (is_input) { + if (inputs.count(def_name) == 0) { + inputs.emplace(def_name, input_counter); + } + input_counter++; + } else { + outputs.emplace(def_name, output_counter++); + } + }); + } + + int CreateOrGetInternalAllocatorOffset(const std::string& def_name) { + if (internal_allocator_offset.count(def_name) > 0) { + return internal_allocator_offset.at(def_name); + } + internal_allocator_offset.insert(std::make_pair(def_name, offset_count)); + return offset_count++; + } +}; + +enum class NodeArgTileAttribute : int { + None = 0, + Forward = 1, + Backward = 2, + NoMerger = 3, +}; + +// NupharSubgraphUnit is a data struct under Ort Subgraph. 
+// It is a customized data struct in nuphar +// to enable concurrent function codegen within a Ort Kernel (which maps to an Ort Subgraph) +struct NupharSubgraphUnit { + NupharSubgraphUnit() { + id_ = counter++; + } + + std::vector nodes; + + // inputs include each input of this NupharSubgraphUnit (input of Partition AND this NupharSubgraphUnit at the same time) + // it also includes initializers + std::vector inputs; + + // outputs include each output of this NupharSubgraphUnit and real_output (output of Partition AND this NupharSubgraphUnit at the same time) + std::vector outputs; + + // initializers include each intializer of this NupharSubgraphUnit + std::map initializers; + + // optional + std::vector input_attrs; + std::vector output_attrs; + + bool IsSingleNode() const { + return nodes.size() == 1; + } + + const std::string& Name() const { + return nodes.front()->Name(); + } + + std::string UniqueId() const { + return std::to_string(id_); + } + + public: + // counter used for subgraph id + // reset outside after cache generated + // to avoid same inference session continue + // increase the counter + thread_local static int64_t counter; + + private: + int64_t id_; +}; + +} // namespace nuphar +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/nuphar/common/nuphar_tvm_utils.cc b/onnxruntime/core/providers/nuphar/common/nuphar_tvm_utils.cc new file mode 100644 index 0000000000000..ec6566b0f8ff4 --- /dev/null +++ b/onnxruntime/core/providers/nuphar/common/nuphar_tvm_utils.cc @@ -0,0 +1,174 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/nuphar/common/nuphar_tvm_utils.h" + +#include "core/providers/nuphar/common/nuphar_subgraph.h" +#include "core/providers/nuphar/common/nuphar_settings.h" +#include "core/codegen/common/common.h" +#include "core/codegen/common/target_info.h" + +#include "core/common/logging/logging.h" +#include "core/platform/env.h" +#include "core/providers/common.h" +#include "gsl/gsl_util" +#include +#include +#include +#include +namespace fs = std::experimental::filesystem; + +namespace onnxruntime { +namespace nuphar { + +static bool GetOrCreateTVMModuleCacheDirectory(fs::path& path, bool create) { + codegen::CodeGenSettings& settings = codegen::CodeGenSettings::Instance(); + + if (!settings.HasOption(kNupharCachePath)) + return false; + + std::string version; + if (settings.HasOption(kNupharCacheVersion)) { + version = settings.GetOptionValue(kNupharCacheVersion); + } else { + version = kNupharCacheVersion_Current; + } + + path = settings.GetOptionValue(kNupharCachePath); + if (!create && !fs::is_directory(path)) + return false; + + if (!fs::is_directory(path)) + if (!fs::create_directory(path)) { + throw std::runtime_error("Failed to create directory " + path.string()); + } + + path.append(version); + if (!create && !fs::is_directory(path)) + return false; + + if (!fs::is_directory(path)) + if (!fs::create_directory(path)) { + throw std::runtime_error("Failed to create directory " + path.string()); + } + + return true; +} + +static bool GetCacheSoFilePath(std::string& so_path) { + codegen::CodeGenSettings& settings = codegen::CodeGenSettings::Instance(); + fs::path path; + if (!GetOrCreateTVMModuleCacheDirectory(path, /*create*/ false)) + return false; + + auto so_name = settings.GetOptionValue(kNupharCacheSoName); + path.append(so_name); + if (fs::is_regular_file(path)) { + so_path = path.string(); + return true; + } + return false; +} + +static void* 
GetFuncFromLibrary(const std::string& so_path, const std::string& func_name, bool throw_if_not_found = true) { + void* so_handle; + ORT_ENFORCE(Env::Default().LoadDynamicLibrary(so_path, &so_handle).IsOK()); + void* func = nullptr; + Status s = Env::Default().GetSymbolFromLibrary(so_handle, func_name, &func); + if (throw_if_not_found && !s.IsOK()) + ORT_ENFORCE(false, "Cannot find ", func_name, " in ", so_path); + return func; +} + +static bool disable_caching_due_to_checksum_failure = false; + +static bool VerifyTVMModuleChecksum(const std::string& so_path) { + static std::string last_so_path; + static bool last_checksum_validated = false; + static std::mutex checksum_mutex; + if (last_so_path != so_path) { + std::lock_guard lock(checksum_mutex); + if (last_so_path != so_path) { + disable_caching_due_to_checksum_failure = false; // reset disabled caching for a new file + last_so_path = so_path; + void* f = GetFuncFromLibrary(so_path, "_ORTInternal_GetCheckSum", /*throw_if_not_found*/ false); + if (f) { + typedef void (*GetChecksumFunc)(const char*&, size_t&); + GetChecksumFunc func = reinterpret_cast(f); + const char* model_checksum; + size_t model_checksum_len; + func(model_checksum, + model_checksum_len); + + codegen::CodeGenSettings& setting = codegen::CodeGenSettings::Instance(); + // When checksum is expected by dll/so, user must set environment variable + // NUPHAR_CACHE_MODEL_CHECKSUM from md5 digest of running model. + // User may choose to run with base model or simplified mode and any match + // would be regarded as validated. + // Note that checksum validation here is not designed as a security measurement, + // so checksum compute is not done inside ORT. + last_checksum_validated = + setting.OptionMatches( + kNupharCacheModelChecksum, + std::string(model_checksum, model_checksum_len)); + + if (!last_checksum_validated) { + LOGS_DEFAULT(CODEGEN_SETTINGS_LOG_LEVEL) << "Cache checksum validation failed, using JIT..."; + disable_caching_due_to_checksum_failure = true; + } + } else { + // do not validate checksum if dll didn't require it (usually during debugging) + // TODO: force checksum validation in final release + last_checksum_validated = true; + } + } + } + return last_checksum_validated; +} + +tvm::runtime::PackedFunc LoadTVMPackedFuncFromCache(const std::string& func_name) { + std::string so_path; + if (!GetCacheSoFilePath(so_path)) + return nullptr; + + if (!VerifyTVMModuleChecksum(so_path)) + return nullptr; + + tvm::runtime::Module module = tvm::runtime::Module::LoadFromFile(so_path); + tvm::runtime::PackedFunc func = module.GetFunction(func_name); + if (func == nullptr) { + LOGS_DEFAULT(CODEGEN_SETTINGS_LOG_LEVEL) << "Cannot find " << func_name << " in cache, using JIT..."; + } + return func; +} + +thread_local int saved_tvm_model_cnt = 0; + +void SaveTVMModuleToCache(const std::string& filename, tvm::runtime::Module& module) { + fs::path path; + + if (disable_caching_due_to_checksum_failure) + return; + + static std::mutex save_cache_mutex; + static std::unordered_set existing_files; + std::lock_guard lock(save_cache_mutex); + if (existing_files.count(filename) == 0 && + GetOrCreateTVMModuleCacheDirectory(path, /*create*/ true)) { + existing_files.insert(filename); + path.append("cached_" + std::to_string(saved_tvm_model_cnt++) + ".o"); + if (fs::exists(path)) { + LOGS_DEFAULT(CODEGEN_SETTINGS_LOG_LEVEL) << "Object file " << path << " already exists, skip saving..."; + return; + } + module->SaveToFile(path.string(), "o"); + } +} + +std::string GetPackedFuncName(const 
nuphar::NupharSubgraphUnit& subgraph, const CodeGenTarget& codegen_target) { + // in C, a function does not allow its name starting with a digit. + return NormalizeCppName("_" + subgraph.UniqueId() + " " + codegen_target.GetTargetName()); +} + +} // namespace nuphar +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/nuphar/common/nuphar_tvm_utils.h b/onnxruntime/core/providers/nuphar/common/nuphar_tvm_utils.h new file mode 100644 index 0000000000000..3c26a0c6f61f9 --- /dev/null +++ b/onnxruntime/core/providers/nuphar/common/nuphar_tvm_utils.h @@ -0,0 +1,26 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once +#include +#include + +#include "core/graph/graph.h" + +namespace onnxruntime { +class CodeGenTarget; //forward + +namespace nuphar { + +struct NupharSubgraphUnit; //forward +// Helper functions to create or load from offline cached dll +// note after saving to obj file, we need to use tvm Python to create dll +// using script at onnxruntime/core/codegen/mti/scripts/create_shared.py +tvm::runtime::PackedFunc +LoadTVMPackedFuncFromCache(const std::string& func_name); +void SaveTVMModuleToCache(const std::string& filename, tvm::runtime::Module& module); + +std::string GetPackedFuncName(const nuphar::NupharSubgraphUnit& subgraph, const CodeGenTarget& codegen_target); + +} // namespace nuphar +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/nuphar/common/utils.cc b/onnxruntime/core/providers/nuphar/common/utils.cc new file mode 100644 index 0000000000000..848e368a71da4 --- /dev/null +++ b/onnxruntime/core/providers/nuphar/common/utils.cc @@ -0,0 +1,76 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/nuphar/common/utils.h" + +#include "core/framework/tensorprotoutils.h" +#include "core/providers/common.h" + +namespace onnxruntime { +namespace nuphar { + +bool NodeArgShapeUnknownOnAxis(const NodeArg* def, int64_t axis) { + auto shape = def->Shape(); + axis = HandleNegativeAxis(axis, shape->dim_size()); + ORT_ENFORCE(axis < shape->dim_size()); + auto dim = shape->dim(axis); + return dim.has_dim_param() || (!dim.has_dim_param() && !dim.has_dim_value()); +} + +bool HasUnknownShapeOnAxis(const ConstPointerContainer>& defs, int64_t axis) { + for (const NodeArg* def : defs) { + if (NodeArgShapeUnknownOnAxis(def, axis)) { + return true; + } + } + return false; +} + +bool HasUnknownShapeOnAxes(const NodeArg* def, std::vector& axes) { + for (auto axis : axes) { + if (NodeArgShapeUnknownOnAxis(def, axis)) { + return true; + } + } + return false; +} + +Status GetSliceAxesFromTensorProto(std::vector& axes, + const ONNX_NAMESPACE::TensorProto& axes_tp) { + size_t tp_sz_in_bytes; + ORT_RETURN_IF_ERROR(utils::GetSizeInBytesFromTensorProto<0>(axes_tp, &tp_sz_in_bytes)); + OrtValue ort_value; + std::unique_ptr data(new char[tp_sz_in_bytes]); + +#define UNPACK_TENSOR(T) \ + T* p = reinterpret_cast(data.get()); \ + ORT_RETURN_IF_ERROR(utils::UnpackTensor( \ + axes_tp, \ + axes_tp.raw_data().size() ? 
axes_tp.raw_data().data() : nullptr, \ + axes_tp.raw_data().size(), \ + p, \ + tp_sz_in_bytes / sizeof(T))); \ + std::vector tmp_axes(p, p + tp_sz_in_bytes / sizeof(T)); + + switch (axes_tp.data_type()) { + case ONNX_NAMESPACE::TensorProto_DataType_INT32: { + UNPACK_TENSOR(int32_t); + for (auto axis : tmp_axes) { + axes.push_back(static_cast(axis)); + } + break; + } + case ONNX_NAMESPACE::TensorProto_DataType_INT64: { + UNPACK_TENSOR(int64_t); + axes.insert(axes.end(), tmp_axes.begin(), tmp_axes.end()); + break; + } + default: + ORT_NOT_IMPLEMENTED("Unimplemented type: ", axes_tp.data_type()); + } + + return Status::OK(); +} + +} // namespace nuphar +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/nuphar/common/utils.h b/onnxruntime/core/providers/nuphar/common/utils.h new file mode 100644 index 0000000000000..a2c1a702f606d --- /dev/null +++ b/onnxruntime/core/providers/nuphar/common/utils.h @@ -0,0 +1,23 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once +#include "core/graph/graph.h" + +// forward declaration +struct OrtAllocatorInfo; + +namespace onnxruntime { +namespace nuphar { + +bool NodeArgShapeUnknownOnAxis(const NodeArg* def, int64_t axis); + +bool HasUnknownShapeOnAxis(const ConstPointerContainer>& defs, int64_t axis); + +bool HasUnknownShapeOnAxes(const NodeArg* def, std::vector& axes); + +Status GetSliceAxesFromTensorProto(std::vector& axes, + const ONNX_NAMESPACE::TensorProto& axes_tp); + +} // namespace nuphar +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/nuphar/compiler/codegen_manager.cc b/onnxruntime/core/providers/nuphar/compiler/codegen_manager.cc new file mode 100644 index 0000000000000..981e14d2bec5a --- /dev/null +++ b/onnxruntime/core/providers/nuphar/compiler/codegen_manager.cc @@ -0,0 +1,233 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/nuphar/compiler/codegen_manager.h" + +#include "core/codegen/common/op_macro.h" +#include "core/codegen/passes/op_ir_creator/all_ops.h" +#include "core/codegen/passes/scheduler/all_schedules.h" +#include "core/codegen/passes/weight_layout/transpose_2d.h" +#include "core/codegen/passes/weight_layout/vertical_stripes_2d.h" +#include "core/providers/nuphar/compiler/x86/op_ir_creator/all_ops.h" +#include "core/providers/nuphar/compiler/x86/scheduler/nuphar_scheduler.h" + +namespace onnxruntime { +namespace codegen { +// explicit instantiation +template class RegistryBase; +} // namespace codegen + +namespace nuphar { + +//// All Creator instance registration +// 1. Create Customized Op IR creator instances + +// BEGIN: NupharTVM X86 IR creator classes + +#define ADD_OP_ITEM(name) \ + op_ir_registry->Register(std::move(std::make_unique())); + +#define REDUCE_V_OP(name) ADD_OP_ITEM(name) +#define UNARY_OP(name) ADD_OP_ITEM(name) + +static void RegisterAllNupharX86OpIRCreators(tvm_codegen::OpIRRegistry* op_ir_registry) { + LIST_ALL_X86_OPS() +} + +#undef ADD_OP_ITEM +#undef REDUCE_V_OP +#undef UNARY_OP + +// END: NupharTVM X86 IR creator classes + +// 2. 
Create Scheduler instances +// BEGIN: Nuphar Scheduler classes + +static void RegisterAllNupharSchedulers(tvm_codegen::TVMScheduleRegistry* sched_registry) { + // Add Generic TVM Rule schedules + sched_registry->Register( + std::move(std::make_unique())); + sched_registry->Register( + std::move(std::make_unique())); + sched_registry->Register( + std::move(std::make_unique())); + + // Add Generic OpType schedules + sched_registry->Register( + std::move(std::make_unique())); + + // Add NupharX86 TVM Rule schedules + sched_registry->Register( + std::move(std::make_unique())); + sched_registry->Register( + std::move(std::make_unique())); + + // Add NupharX86 Tensorization schedules + sched_registry->Register( + std::move(std::make_unique())); + sched_registry->Register( + std::move(std::make_unique())); + + // Add NupharX86 OpType schedules + sched_registry->Register( + std::move(std::make_unique())); + sched_registry->Register( + std::move(std::make_unique())); + sched_registry->Register( + std::move(std::make_unique())); + sched_registry->Register( + std::move(std::make_unique())); + sched_registry->Register( + std::move(std::make_unique())); + + // Add NupharX86 use count schedules + sched_registry->Register( + std::move(std::make_unique())); + sched_registry->Register( + std::move(std::make_unique())); + + // Add NupharX86 partial result schedules + sched_registry->Register( + std::move(std::make_unique())); +} + +// END: Nuphar Scheduler classes + +// 3. Create Weight layout instances +// BEGIN: Nuphar Weight Layouts classes +static void RegisterAllNupharWeightLayouts(tvm_codegen::WeightLayoutRegistry* layout_registry) { + layout_registry->Register( + std::move(std::make_unique(ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT, 8))); + layout_registry->Register( + std::move(std::make_unique(ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT))); + layout_registry->Register( + std::move(std::make_unique(ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT8))); + layout_registry->Register( + std::move(std::make_unique(ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_UINT8))); + layout_registry->Register( + std::move(std::make_unique(ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT16))); +} + +// END: Nuphar Weight Layouts classes + +//// All Plugins for Nuphar provider +// 1. 
Plugin IR creator classes + +// BEGIN: Nuphar TVM X86 IR creator classes +#define ADD_OP_ITEM(name) \ + dispatcher->Register(#name, registry->Get(NUPHAR_TVM_X86_OP_IR_CREATOR_STRING(name))); + +#define REDUCE_V_OP(name) ADD_OP_ITEM(name) +#define UNARY_OP(name) ADD_OP_ITEM(name) + +static void RegisterNupharX86Dispatcher(const std::shared_ptr& builder, + const tvm_codegen::OpIRRegistry* registry) { + auto dispatcher = std::make_unique("OptypeNupharTVMX86Creators"); + LIST_ALL_X86_OPS() + builder->InsertDispatcher(std::move(dispatcher)); +} + +#undef ADD_OP_ITEM +#undef REDUCE_V_OP +#undef UNARY_OP +// END: Nuphar TVM X86 IR creator classes + +// 2 Plugin Scheduler classes + +// BEGIN: TVM rule Scheduler +static void RegisterNupharX86TVMRuleSchedulers(const std::shared_ptr& builder, + const tvm_codegen::TVMScheduleRegistry* registry) { + auto dispatcher = std::make_unique("NupharX86TVMRuleSchedulers"); + + // Register a scheduler for TVM External Tensor + dispatcher->Register(tvm_codegen::GetTVMOpRule(tvm_codegen::TVMOpRuleType::Extern), + registry->Get(TVM_SCHEDULER_STRING(Extern, NupharX86TVMRule))); + // Register a scheduler for TVM Reduce Tensor + dispatcher->Register(tvm_codegen::GetTVMOpRule(tvm_codegen::TVMOpRuleType::ComputeReduce), + registry->Get(TVM_SCHEDULER_STRING(Reduce, NupharX86TVMRule))); + + builder->InsertDispatcher(std::move(dispatcher)); +} +// END: TVM rule Scheduler + +// BEGIN: ORT OpType Scheduler +static void RegisterNupharX86OrtOpTypeSchedulers(const std::shared_ptr& builder, + const tvm_codegen::TVMScheduleRegistry* registry) { + auto dispatcher = std::make_unique("NupharX86OrtOpTypeSchedulers"); + + // Register a scheduler for Ort Softmax OpType + dispatcher->Register("Softmax", + registry->Get(TVM_SCHEDULER_STRING(Softmax, NupharX86OrtOpType))); + + dispatcher->Register("Split", + registry->Get(TVM_SCHEDULER_STRING(Split, NupharX86OrtOpType))); + + builder->InsertDispatcher(std::move(dispatcher)); +} +// END: ORT OpType Scheduler + +// BEGIN: Reuse Count Analysis Scheduler +static void RegisterNupharX86UseCountSchedulers(const std::shared_ptr& builder, + const tvm_codegen::TVMScheduleRegistry* registry) { + auto dispatcher = std::make_unique("NupharX86UseCountSchedulers"); + + // Register a scheduler for Reuse count > 1 + dispatcher->Register("True", + registry->Get(TVM_SCHEDULER_STRING(True, NupharX86UseCount))); + + // Register a scheduler for Reuse count <= 1 + dispatcher->Register("False", + registry->Get(TVM_SCHEDULER_STRING(False, NupharX86UseCount))); + + builder->InsertDispatcher(std::move(dispatcher)); +} +// END: Reuse Count Analysis Scheduler + +// BEGIN: Partial Result Scheduler +static void RegisterNupharX86PartialResultSchedulers(const std::shared_ptr& builder, + const tvm_codegen::TVMScheduleRegistry* registry) { + auto dispatcher = std::make_unique("NupharX86PartialResultSchedulers"); + dispatcher->Register("True", + registry->Get(TVM_SCHEDULER_STRING(True, NupharX86PartialResult))); + + builder->InsertDispatcher(std::move(dispatcher)); +} +// END: Partial Result Scheduler + +TVMCodeGenManager::TVMCodeGenManager() { + op_ir_registry_ = std::make_unique(); + layout_registry_ = std::make_unique(); + schedule_registry_ = std::make_unique(); +} + +void TVMCodeGenManager::Initialization() { + RegisterAllNupharX86OpIRCreators(op_ir_registry_.get()); + RegisterAllGenericOpIRCreators(op_ir_registry_.get()); + + RegisterAllNupharWeightLayouts(layout_registry_.get()); + RegisterAllNupharSchedulers(schedule_registry_.get()); +} + +// TODO Add isa support 
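+// SetCodeGenHandle wires the registries owned by this manager into the given
+// NupharCodeGenHandle: it shares the weight layout registry and builds the
+// op IR and schedule dispatcher sets that later compiler passes look up.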
+void TVMCodeGenManager::SetCodeGenHandle(NupharCodeGenHandle* handle) { + // layout registry + handle->layout_registry = layout_registry_.get(); + + // Op IR creators + handle->op_ir_builder = + std::make_shared("Nuphar_Op_IR_Builder"); + RegisterNupharX86Dispatcher(handle->op_ir_builder, op_ir_registry_.get()); + RegisterGenericOrtOpTypeDispatcher(handle->op_ir_builder, op_ir_registry_.get()); + + // Schedulers + handle->schedule_builder = + std::make_shared("Nuphar_Schedule_Builder"); + + RegisterNupharX86TVMRuleSchedulers(handle->schedule_builder, schedule_registry_.get()); + RegisterNupharX86OrtOpTypeSchedulers(handle->schedule_builder, schedule_registry_.get()); + RegisterNupharX86UseCountSchedulers(handle->schedule_builder, schedule_registry_.get()); + RegisterNupharX86PartialResultSchedulers(handle->schedule_builder, schedule_registry_.get()); +} + +} // namespace nuphar +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/nuphar/compiler/codegen_manager.h b/onnxruntime/core/providers/nuphar/compiler/codegen_manager.h new file mode 100644 index 0000000000000..75a88002fe9fb --- /dev/null +++ b/onnxruntime/core/providers/nuphar/compiler/codegen_manager.h @@ -0,0 +1,43 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once +#include "core/codegen/passes/op_ir_creator/tvm_op_creator.h" +#include "core/codegen/passes/op_ir_creator/tvm_ir_builder.h" +#include "core/codegen/passes/scheduler/tvm_schedule_builder.h" +#include "core/codegen/passes/weight_layout/weight_layout.h" +#include "core/providers/nuphar/compiler/nuphar_handle.h" + +namespace onnxruntime { +namespace nuphar { + +// TVMCodeGenManager contains all registries +// including 1) TVM IR builder registry +// 2) Weight layout transformer registry +// 3) TVM scheduler registry, etc. +// These registries include all applicable passes for specific arch +// AND might also include non-applicable passes, like passes for another arch. + +// TVMCodeGenManager keeps the ownerships of all registries, passes, +// and planners. + +// TVMCodeGenManager also sets NupharCodeGenHandle for a specific arch. + +class TVMCodeGenManager { + public: + TVMCodeGenManager(); + + // TODO add a list of condition to handle dynamic registration + void Initialization(); + + // TODO: add target as an input + void SetCodeGenHandle(NupharCodeGenHandle* handle); + + private: + std::unique_ptr op_ir_registry_; + std::unique_ptr layout_registry_; + std::unique_ptr schedule_registry_; +}; + +} // namespace nuphar +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/nuphar/compiler/func_info.cc b/onnxruntime/core/providers/nuphar/compiler/func_info.cc new file mode 100644 index 0000000000000..711a396a8de87 --- /dev/null +++ b/onnxruntime/core/providers/nuphar/compiler/func_info.cc @@ -0,0 +1,562 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
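+
+// This file populates NupharFuncInfo from a compiled subgraph: it records the
+// packed TVM function, maps ORT inputs/outputs to allocator slots, collects
+// initializer tensors, and captures dtype/shape metadata (including symbolic
+// dimensions). FillScanExecInfo additionally extracts Scan control-flow
+// attributes (directions and axes) into a ScanExecInfo.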
+ +#include "core/providers/nuphar/compiler/func_info.h" + +#include "core/providers/nuphar/runtime/control_flow/scan_exec_ctx.h" +#include "core/framework/op_kernel.h" +#include "core/framework/tensorprotoutils.h" +#include "core/codegen/common/common.h" +#include "core/providers/nuphar/common/analysis/subgraph_codegen_stats.h" +#include + +// from onnxruntime_typeinf.cc, in global namespace +const onnxruntime::DataTypeImpl* ElementTypeFromProto(int type); + +namespace onnxruntime { +namespace nuphar { + +static void FillBasicFuncInfo(NupharFuncInfo* func_info, + nuphar::OrtSubgraphAllocationInfo* partition_info, + const nuphar::NupharSubgraphUnit& subgraph, + const NupharCodeGenCtx& codegen_ctx, + tvm::Target tvm_target, + tvm::runtime::PackedFunc packed_func, + const std::string& name) { + ORT_ENFORCE(nullptr != func_info); + ORT_ENFORCE(nullptr != partition_info); + + func_info->name = name; + func_info->packed_func = packed_func; + func_info->device_type = static_cast(tvm_target->device_type); + + int tvm_input_idx = 0; + int def_index = 0; + // Handle inputs + func_info->ort_input_count = subgraph.inputs.size(); + // Assign Input meta + for (auto& def : subgraph.inputs) { + // fill in allocator info + NupharFuncInfo::AllocatorMeta input_allocator; + if (partition_info->inputs.count(def->Name()) > 0) { + // if an input is from external + input_allocator.index = partition_info->inputs.at(def->Name()); + input_allocator.is_external = true; + func_info->ort_input_allocator_is_collided_output.push_back(false); + } else if (partition_info->outputs.count(def->Name()) > 0) { + // if an input is from a previous real output + input_allocator.index = partition_info->outputs.at(def->Name()); + input_allocator.is_external = true; // a real output is always from external + func_info->ort_input_allocator_is_collided_output.push_back(true); + } else { + // else, an input is from an internal allocator + input_allocator.index = partition_info->CreateOrGetInternalAllocatorOffset(def->Name()); + input_allocator.is_external = false; + func_info->ort_input_allocator_is_collided_output.push_back(false); + } + + func_info->ort_input_allocators.push_back(input_allocator); + + if (codegen_ctx.IsInitializer(def->Name())) { + ++def_index; + continue; // skip initializers + } + + // fill in func args + NupharFuncInfo::FuncArgMeta input_meta; + input_meta.dtype = ElementTypeFromProto(def->TypeAsProto()->tensor_type().elem_type()); + input_meta.ort_arg_index = def_index; + + // fill in shape info and symobolic info + for (int dim = 0; dim < gsl::narrow(ShapeRank(def)); ++dim) { + if (ShapeHasSymbol(def, dim)) { + input_meta.inferred_shape.push_back(Dimension_Unknown); + input_meta.dim_symbols.push_back(std::make_pair(gsl::narrow(dim), ShapeSymbol(def, dim))); + } else if (ShapeHasValue(def, dim)) { + input_meta.inferred_shape.push_back(ShapeValue(def, dim)); + } else { + input_meta.inferred_shape.push_back(Dimension_Unknown); + } + } + + func_info->input_metas.push_back(input_meta); + + ++tvm_input_idx; + ++def_index; + } + + // Handle initializers + // Initializer meta + std::vector& intializers = func_info->intializers; + // Assign Initializer meta + for (const auto& item : codegen_ctx.GetWeightLayoutMap()) { + const WeightLayoutCodegenInfo* layout_info = item.second.get(); + bool is_marshalled = layout_info->is_marshalled; + const Tensor* t = + is_marshalled ? 
layout_info->marshalled_initializer + : codegen_ctx.GetOrtInitializerTensor(item.first); + + intializers.push_back(t); + ++tvm_input_idx; + } + + // set input_count = the number of inputs + the number of initializers + func_info->func_input_count = gsl::narrow(tvm_input_idx); + + // Handle Outputs + + func_info->ort_output_count = subgraph.outputs.size(); + // Assign Output meta + int tvm_output_idx = 0; + std::unordered_map visited_output_def_indices; + def_index = 0; + for (auto& def : subgraph.outputs) { + // fill in allocator info + NupharFuncInfo::AllocatorMeta output_allocator; + if (partition_info->outputs.count(def->Name()) > 0) { + // if an output is from external + output_allocator.index = partition_info->outputs.at(def->Name()); + output_allocator.is_external = true; + } else { + // else, an output is from an internal allocator + output_allocator.index = partition_info->CreateOrGetInternalAllocatorOffset(def->Name()); + output_allocator.is_external = false; + } + + func_info->ort_output_allocators.push_back(output_allocator); + + // check output alias + const NodeArg* source_def = Promote(codegen_ctx.GetGraphStats()) + ->SourceDefOfOutputAlias(def); + + if (nullptr != source_def) { + // if def is an alias + auto key = GetKey(source_def); + if (visited_output_def_indices.count(key) != 0) { + // source_def has visisted ==> def is a duplicated output + // record the pair (dst of ort arg index, src of tvm func index) + func_info->ort_aliased_output_to_func_indices.emplace_back(def_index, + func_info->func_input_count + + visited_output_def_indices[key]); + + ++def_index; + continue; + } + // update visited_output_def_indices + visited_output_def_indices.insert(std::make_pair(key, gsl::narrow_cast(tvm_output_idx))); + } else { + auto key = GetKey(def); + if (visited_output_def_indices.count(key) != 0) { + // def has visisted ==> def is a duplicated output + // record the pair (dst of ort arg index, src of tvm func index) + func_info->ort_aliased_output_to_func_indices.emplace_back(def_index, + func_info->func_input_count + + visited_output_def_indices[key]); + + ++def_index; + continue; + } + visited_output_def_indices.insert(std::make_pair(key, gsl::narrow_cast(tvm_output_idx))); + } + + NupharFuncInfo::FuncArgMeta output_meta; + output_meta.dtype = ElementTypeFromProto(def->TypeAsProto()->tensor_type().elem_type()); + output_meta.ort_arg_index = def_index; + + // fill in shape info and symobolic info + for (int dim = 0; dim < gsl::narrow(ShapeRank(def)); ++dim) { + if (ShapeHasSymbol(def, dim)) { + auto p = std::make_pair(gsl::narrow(dim), ShapeSymbol(def, dim)); + output_meta.dim_symbols.push_back(p); + output_meta.inferred_shape.push_back(Dimension_Unknown); + } else if (ShapeHasValue(def, dim)) { + output_meta.inferred_shape.push_back(ShapeValue(def, dim)); + } else { + output_meta.inferred_shape.push_back(Dimension_Unknown); + } + } + + func_info->output_metas.push_back(output_meta); + ++def_index; + ++tvm_output_idx; + } + + // set output_count as the real output count + func_info->func_output_count = gsl::narrow_cast(tvm_output_idx); + + // set tvm type_codes + func_info->type_codes.resize(func_info->func_input_count + func_info->func_output_count, TVMTypeCode::kNDArrayContainer); +} + +static void FillScanExecInfo(NupharFuncInfo* func_info, + nuphar::OrtSubgraphAllocationInfo* partition_info, + const Node& node, + const NupharCodeGenCtx& codegen_ctx, + tvm::Target tvm_target, + tvm::runtime::PackedFunc packed_func, + const std::string& name) { + ORT_ENFORCE(nullptr != 
func_info); + ORT_ENFORCE(nullptr != partition_info); + + // create Scan control-flow info + auto scan_info = std::make_unique(); + + int64_t num_state_variables; + int64_t num_scan_inputs; + int64_t num_scan_outputs; + + ProtoHelperNodeContext ctx(node); + OpNodeProtoHelper attrs(&ctx); + + // extract num_scan_inputs + bool attr_is_ok = attrs.GetAttr("num_scan_inputs", &num_scan_inputs).IsOK(); + ORT_UNUSED_PARAMETER(attr_is_ok); + ORT_ENFORCE_DEBUG(attr_is_ok); + + auto subgraph = GetSubgraph(node); + ORT_ENFORCE(subgraph != nullptr); + size_t num_variadic_inputs = subgraph->GetInputs().size(); + size_t num_variadic_outputs = subgraph->GetOutputs().size(); + + num_state_variables = gsl::narrow(num_variadic_inputs) - num_scan_inputs; + num_scan_outputs = gsl::narrow(num_variadic_outputs) - num_state_variables; + + // Set ScanExecInfo's parameter count meta + scan_info->num_state_variables = num_state_variables; + scan_info->num_scan_inputs = num_scan_inputs; + scan_info->num_scan_outputs = num_scan_outputs; + scan_info->num_scan_implicit_inputs = gsl::narrow_cast(node.ImplicitInputDefs().size()); + + // ScanExecInfo's control flow Meta + std::vector& scan_input_forwards = scan_info->scan_input_forwards; + std::vector& scan_output_forwards = scan_info->scan_output_forwards; + std::vector& scan_input_axes = scan_info->scan_input_axes; + std::vector& scan_output_axes = scan_info->scan_output_axes; + + scan_input_forwards.resize(num_scan_inputs); + scan_output_forwards.resize(num_scan_outputs); + + // extract directions and axes + std::vector scan_input_directions; + std::vector scan_output_directions; + + // scan_input_directions + if (attrs.GetAttrs("scan_input_directions", scan_input_directions).IsOK()) { + ORT_ENFORCE(gsl::narrow_cast(scan_input_directions.size()) == num_scan_inputs, + "Number of entries in 'scan_input_directions ' was ", scan_input_directions.size(), + ". Must match 'num_scan_inputs' of ", num_scan_inputs); + ORT_ENFORCE(std::all_of(scan_input_directions.cbegin(), scan_input_directions.cend(), + [](int64_t i) { return i == 0 || + i == 1; }), + "Invalid values in 'scan_input_directions'. 0 == forward. 1 == reverse."); + } else { + // default to forward + scan_input_directions = std::vector(num_scan_inputs, 0); + } + + // scan_input_forwards + for (size_t i = 0; i < gsl::narrow(num_scan_inputs); ++i) { + scan_input_forwards[i] = scan_input_directions[i] == 0; + } + + // scan_output_directions + if (attrs.GetAttrs("scan_output_directions", scan_output_directions).IsOK()) { + ORT_ENFORCE(gsl::narrow_cast(scan_output_directions.size()) == num_scan_outputs, + "Number of entries in 'scan_output_directions ' was ", scan_output_directions.size(), + ". Must match 'num_scan_outputs' of ", num_scan_outputs); + ORT_ENFORCE(std::all_of(scan_output_directions.cbegin(), scan_output_directions.cend(), + [](int64_t i) { return i == 0 || + i == 1; }), + "Invalid values in 'scan_output_directions'. 0 == forward. 1 == reverse."); + } else { + // default to forward + scan_output_directions = std::vector(num_scan_outputs, 0); + } + + // scan_output_forwards + for (size_t i = 0; i < gsl::narrow(num_scan_outputs); ++i) { + scan_output_forwards[i] = scan_output_directions[i] == 0; + } + + // scan_input_axes + if (attrs.GetAttrs("scan_input_axes", scan_input_axes).IsOK()) { + ORT_ENFORCE(gsl::narrow_cast(scan_input_axes.size()) == num_scan_inputs, + "Number of entries in 'scan_input_axes ' was ", scan_input_axes.size(), + ". 
Must match 'num_scan_inputs' of ", num_scan_inputs); + + } else { + // default to axis 0 + scan_input_axes = std::vector(num_scan_inputs, 0); + } + + // scan_output_axes + if (attrs.GetAttrs("scan_output_axes", scan_output_axes).IsOK()) { + ORT_ENFORCE(gsl::narrow_cast(scan_output_axes.size()) == num_scan_outputs, + "Number of entries in 'scan_output_axes ' was ", scan_output_axes.size(), + ". Must match 'num_scan_outputs' of ", num_scan_outputs); + + } else { + // default to axis 0 + scan_output_axes = std::vector(num_scan_outputs, 0); + } + + // handle NupharFuncInfo + func_info->name = name; + func_info->packed_func = packed_func; + func_info->device_type = static_cast(tvm_target->device_type); + + int tvm_input_idx = 0; + // Handle state inputs & inputs + func_info->ort_input_count = num_variadic_inputs; + + // assign state inputs & inputs + for (size_t ort_input_idx = 0; ort_input_idx < num_variadic_inputs; ++ort_input_idx) { + // fill in allocator info + NupharFuncInfo::AllocatorMeta input_allocator; + const NodeArg* main_graph_def = node.InputDefs()[ort_input_idx]; + ORT_ENFORCE(nullptr != main_graph_def); + if (partition_info->inputs.count(main_graph_def->Name()) > 0) { + // if an input is from external + input_allocator.index = partition_info->inputs.at(main_graph_def->Name()); + input_allocator.is_external = true; + func_info->ort_input_allocator_is_collided_output.push_back(false); + } else if (partition_info->outputs.count(main_graph_def->Name()) > 0) { + // if an input is from a previous real output + input_allocator.index = partition_info->outputs.at(main_graph_def->Name()); + input_allocator.is_external = true; // a real output is always from external + func_info->ort_input_allocator_is_collided_output.push_back(true); + } else { + // else, an input is from an internal allocator + input_allocator.index = partition_info->CreateOrGetInternalAllocatorOffset(main_graph_def->Name()); + input_allocator.is_external = false; + func_info->ort_input_allocator_is_collided_output.push_back(false); + } + + func_info->ort_input_allocators.push_back(input_allocator); + + const NodeArg* def = subgraph->GetInputs()[ort_input_idx]; + ORT_ENFORCE(nullptr != def); + + if (ort_input_idx >= gsl::narrow(num_state_variables)) { + // initializer should only happen in real inputs, not in state inputs + if (codegen_ctx.IsInitializer(def->Name())) { + continue; // skip initializers + } + } + + NupharFuncInfo::FuncArgMeta input_meta; + input_meta.dtype = ElementTypeFromProto(def->TypeAsProto()->tensor_type().elem_type()); + input_meta.ort_arg_index = gsl::narrow_cast(ort_input_idx); + + // fill in shape info and symobolic info + for (int dim = 0; dim < gsl::narrow(ShapeRank(def)); ++dim) { + if (ShapeHasSymbol(def, dim)) { + auto p = std::make_pair(gsl::narrow(dim), ShapeSymbol(def, dim)); + input_meta.dim_symbols.push_back(p); + input_meta.inferred_shape.push_back(Dimension_Unknown); + } else if (ShapeHasValue(def, dim)) { + input_meta.inferred_shape.push_back(ShapeValue(def, dim)); + } else { + input_meta.inferred_shape.push_back(Dimension_Unknown); + } + } + + func_info->input_metas.push_back(input_meta); + ++tvm_input_idx; + } + + size_t ort_input_idx = num_variadic_inputs; + // Handle implicit inputs + for (const NodeArg* def : node.ImplicitInputDefs()) { + NupharFuncInfo::AllocatorMeta input_allocator; + if (partition_info->inputs.count(def->Name()) > 0) { + // if an input is from external + input_allocator.index = partition_info->inputs.at(def->Name()); + input_allocator.is_external = true; + 
func_info->ort_input_allocator_is_collided_output.push_back(false); + } else if (partition_info->outputs.count(def->Name()) > 0) { + // if an input is from a previous real output + input_allocator.index = partition_info->outputs.at(def->Name()); + input_allocator.is_external = true; + func_info->ort_input_allocator_is_collided_output.push_back(true); + } else { + // else, an input is from an internal allocator + input_allocator.index = partition_info->CreateOrGetInternalAllocatorOffset(def->Name()); + input_allocator.is_external = false; + func_info->ort_input_allocator_is_collided_output.push_back(false); + } + + func_info->ort_input_allocators.push_back(input_allocator); + + // skip initializers + if (codegen_ctx.IsInitializer(def->Name())) { + ++ort_input_idx; + continue; // skip initializers + } + + NupharFuncInfo::FuncArgMeta input_meta; + input_meta.dtype = ElementTypeFromProto(def->TypeAsProto()->tensor_type().elem_type()); + input_meta.ort_arg_index = gsl::narrow_cast(ort_input_idx); + + std::vector> symbols; + for (int dim = 0; dim < gsl::narrow(ShapeRank(def)); ++dim) { + if (ShapeHasSymbol(def, dim)) { + auto p = std::make_pair(gsl::narrow(dim), ShapeSymbol(def, dim)); + input_meta.dim_symbols.push_back(p); + input_meta.inferred_shape.push_back(Dimension_Unknown); + } else if (ShapeHasValue(def, dim)) { + input_meta.inferred_shape.push_back(ShapeValue(def, dim)); + } else { + input_meta.inferred_shape.push_back(Dimension_Unknown); + } + } + func_info->input_metas.push_back(input_meta); + ++tvm_input_idx; + ++ort_input_idx; + } + + // Handle initializers + // Initializer meta + std::vector& intializers = func_info->intializers; + + // Assign Initializer meta + for (const auto& item : codegen_ctx.GetWeightLayoutMap()) { + const WeightLayoutCodegenInfo* layout_info = item.second.get(); + + bool is_marshalled = layout_info->is_marshalled; + const Tensor* t = + is_marshalled ? layout_info->marshalled_initializer + : codegen_ctx.GetOrtInitializerTensor(item.first); + + intializers.push_back(t); + ++tvm_input_idx; + } + + // set input_count = the number of inputs (real inputs + state inputs) + the number of initializers + func_info->func_input_count = gsl::narrow(tvm_input_idx); + + // Handle State Outputs and Outputs + func_info->ort_output_count = num_variadic_outputs; + + // Since in Scan, we only allow state using output's memory during Execution, not the other around. + // When one input and one state are aliased, the kept one can only be the input. + // Therefore, we do alias detection starting from inputs first. + std::unordered_map visited_output_def_indices; + for (size_t ort_output_idx = gsl::narrow(num_state_variables); ort_output_idx < num_variadic_outputs; ++ort_output_idx) { + const NodeArg* def = subgraph->GetOutputs()[ort_output_idx]; + ORT_ENFORCE(nullptr != def); + const NodeArg* source_def = Promote(codegen_ctx.GetGraphStats()) + ->SourceDefOfOutputAlias(def); + if (nullptr != source_def) { + auto key = GetKey(source_def); + ORT_ENFORCE(visited_output_def_indices.count(key) == 0, + "Scan has alias btw two inputs. 
Nuphar only support aliasing btw state and output in Scan"); + visited_output_def_indices.insert(std::make_pair(key, gsl::narrow(ort_output_idx))); + } else { + auto key = GetKey(def); + visited_output_def_indices.insert(std::make_pair(key, gsl::narrow(ort_output_idx))); + } + } + + // assign state outputs and outputs + size_t tvm_output_idx = 0; + std::unordered_map visited_output_state_func_indices; + for (size_t ort_output_idx = 0; ort_output_idx < num_variadic_outputs; ++ort_output_idx) { + // fill in allocator info + NupharFuncInfo::AllocatorMeta output_allocator; + const NodeArg* main_graph_def = node.OutputDefs()[ort_output_idx]; + ORT_ENFORCE(nullptr != main_graph_def); + if (partition_info->outputs.count(main_graph_def->Name()) > 0) { + output_allocator.index = partition_info->outputs.at(main_graph_def->Name()); + output_allocator.is_external = true; + } else { + output_allocator.index = partition_info->CreateOrGetInternalAllocatorOffset(main_graph_def->Name()); + output_allocator.is_external = false; + } + func_info->ort_output_allocators.push_back(output_allocator); + + // perform alias analysis + const NodeArg* def = subgraph->GetOutputs()[ort_output_idx]; + ORT_ENFORCE(nullptr != def); + const NodeArg* source_def = Promote(codegen_ctx.GetGraphStats()) + ->SourceDefOfOutputAlias(def); + + // Determine alias btw output and state output + auto key = source_def != nullptr ? GetKey(source_def) : GetKey(def); + + int ort_arg_index = gsl::narrow_cast(ort_output_idx); + if (ort_output_idx < gsl::narrow(num_state_variables)) { + // if ort_output_idx is a state output + if (visited_output_def_indices.count(key) != 0) { + // If state output is an alias + // record i_output for the lookup of the aliased output later + visited_output_state_func_indices.insert(std::make_pair(key, gsl::narrow(func_info->func_input_count + tvm_output_idx))); + + // also record ort_aliased_output_to_func_indices + func_info->ort_aliased_output_to_func_indices.push_back(std::make_pair(gsl::narrow(ort_output_idx), + func_info->func_input_count + tvm_output_idx)); + + scan_info->state_to_output_indices.push_back(visited_output_def_indices[key] - gsl::narrow_cast(num_state_variables)); + // override ort_arg_index using the output index + ort_arg_index = visited_output_def_indices[key]; + } else { + // the state output not aliased(no scan output shares with it) + scan_info->state_to_output_indices.push_back(NupharFuncInfo::Index_NonAliasedOutput); + } + } else { + // if ort_output_idx is an output + if (visited_output_state_func_indices.count(key) != 0) { + if (source_def != nullptr) { + // skip a duplicated output, since it was counted in the duplicated state output previously + continue; + } + } + } + + NupharFuncInfo::FuncArgMeta output_meta; + output_meta.dtype = ElementTypeFromProto(def->TypeAsProto()->tensor_type().elem_type()); + output_meta.ort_arg_index = ort_arg_index; + + // shape and symbols + for (int dim = 0; dim < gsl::narrow(ShapeRank(def)); ++dim) { + if (ShapeHasSymbol(def, dim)) { + auto p = std::make_pair(gsl::narrow(dim), ShapeSymbol(def, dim)); + output_meta.dim_symbols.push_back(p); + output_meta.inferred_shape.push_back(Dimension_Unknown); + } else if (ShapeHasValue(def, dim)) { + output_meta.inferred_shape.push_back(ShapeValue(def, dim)); + } else { + output_meta.inferred_shape.push_back(Dimension_Unknown); + } + } + func_info->output_metas.push_back(output_meta); + ++tvm_output_idx; + } + + // set output_count as the real output count + func_info->func_output_count = tvm_output_idx; + + // 
set tvm type_codes
+  func_info->type_codes.resize(func_info->func_input_count + func_info->func_output_count, TVMTypeCode::kNDArrayContainer);
+
+  // set control-flow info
+  func_info->cf_info = std::move(scan_info);
+}
+
+void FillNupharFuncInfo(NupharFuncInfo* func_info,
+                        nuphar::OrtSubgraphAllocationInfo* partition_info,
+                        const nuphar::NupharSubgraphUnit& subgraph,
+                        const NupharCodeGenCtx& codegen_ctx,
+                        tvm::Target tvm_target,
+                        tvm::runtime::PackedFunc packed_func,
+                        const std::string& name) {
+  if (subgraph.nodes.front()->OpType() == "Scan") {
+    FillScanExecInfo(func_info, partition_info, *subgraph.nodes.front(), codegen_ctx, tvm_target, packed_func, name);
+    return;
+  }
+
+  FillBasicFuncInfo(func_info, partition_info, subgraph, codegen_ctx, tvm_target, packed_func, name);
+}
+
+}  // namespace nuphar
+}  // namespace onnxruntime
diff --git a/onnxruntime/core/providers/nuphar/compiler/func_info.h b/onnxruntime/core/providers/nuphar/compiler/func_info.h
new file mode 100644
index 0000000000000..6b2780e6657de
--- /dev/null
+++ b/onnxruntime/core/providers/nuphar/compiler/func_info.h
@@ -0,0 +1,122 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+#include "core/codegen/common/common.h"
+#include "core/common/common.h"
+#include "core/framework/data_types.h"
+#include "core/framework/tensor.h"
+#include "core/graph/graph.h"
+#include "core/providers/nuphar/common/nuphar_subgraph.h"
+#include "core/providers/nuphar/compiler/nuphar_codegen_ctx.h"
+
+#include
+#include
+#include
+
+namespace onnxruntime {
+namespace nuphar {
+
+enum class ControlFlowInfoType : unsigned int {
+  Scan = 1,
+};
+
+// abstract class for control flow info
+struct ControlFlowInfo {
+ private:
+  ControlFlowInfoType type;
+
+ public:
+  ControlFlowInfo(ControlFlowInfoType _type) : type(_type) {}
+
+  virtual ~ControlFlowInfo() = default;
+
+  DYN_PROMOTE_BASE(ControlFlowInfo, ControlFlowInfoType, type)
+};
+
+// Add Promote support for ControlFlowInfo
+// Note here we need to use DYN_PROMOTE instead of DYNAMIC_PROMOTE
+// since ControlFlowInfo is on a critical path
+DYN_PROMOTE(ControlFlowInfo)
+
+// NupharFuncInfo holds a tvm::runtime::PackedFunc (the generated function)
+// and the corresponding static metadata needed to call it, such as argument counts and offsets.
+// Note NupharFuncInfo includes ONLY parameters from codegen
+// but DOES NOT include any runtime information.
+
+// The owner of NupharFuncInfo is currently NupharKernelState.
+// NupharFuncInfo is created in NupharCompiler and is consumed by ExecBlock.
+// Note all vectors are bounded by the number of the PackedFunc's parameters
+// (meaning vector.size() == number of the PackedFunc's parameters),
+// except those prefixed with ort, which are bounded by the number of the ORT op's parameters.
+// A -1 may be inserted as a bubble to keep positions and sizes consistent for later lookup.
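+//
+// Illustrative example (see FillBasicFuncInfo in func_info.cc): a fused subgraph
+// with ORT inputs [X, W], where W is an initializer, and a single ORT output Y
+// yields PackedFunc arguments [X, W (possibly layout-marshalled), Y]. In that case
+// func_input_count == 2 (one real input plus one initializer), func_output_count == 1,
+// and each FuncArgMeta::ort_arg_index records the argument's position among the
+// ORT op's inputs or outputs.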
+struct NupharFuncInfo {
+  // special value for *_func_indices
+  enum : int {
+    Index_NonAliasedOutput = -1,
+  };
+
+  // PackedFunc name
+  std::string name;
+
+  // PackedFunc
+  tvm::runtime::PackedFunc packed_func;
+
+  // TVM DLDevice
+  DLDeviceType device_type;
+
+  struct FuncArgMeta {
+    MLDataType dtype;
+    // shapes with dimensions statically known or inferred at compile time
+    // a symbolic dim will have Dimension_Unknown and be patched at runtime
+    std::vector inferred_shape;
+    std::vector> dim_symbols;
+    int ort_arg_index;
+  };
+
+  std::vector input_metas;
+  std::vector output_metas;
+  std::vector> ort_aliased_output_to_func_indices;  // A pair of (Ort dst index, TVM src index)
+
+  struct AllocatorMeta {
+    int index;
+    bool is_external;
+  };
+
+  std::vector ort_input_allocators;
+  std::vector ort_output_allocators;
+
+  // Note an input can also be an external output,
+  // since a NodeArg can be used by Nodes both inside
+  // and outside a subgraph at the same time.
+  // When that happens, we need to label it as a collided output
+  // and record that external output's allocator index.
+  std::vector ort_input_allocator_is_collided_output;
+
+  // initializers meta
+  std::vector intializers;
+
+  // Note the total arg number == input_count + output_count
+  size_t func_input_count;   // input_count == real inputs + initializers
+  size_t func_output_count;  // real outputs
+
+  // tvm args (including inputs and outputs)
+  std::vector type_codes;
+
+  // control-flow info for the generated function
+  std::unique_ptr cf_info;
+
+  size_t ort_input_count;
+  size_t ort_output_count;
+};
+
+void FillNupharFuncInfo(NupharFuncInfo* func_info,
+                        nuphar::OrtSubgraphAllocationInfo* partition_info,
+                        const nuphar::NupharSubgraphUnit& subgraph,
+                        const NupharCodeGenCtx& codegen_ctx,
+                        tvm::Target tvm_target,
+                        tvm::runtime::PackedFunc packed_func,
+                        const std::string& name);
+
+}  // namespace nuphar
+}  // namespace onnxruntime
diff --git a/onnxruntime/core/providers/nuphar/compiler/initializer_info.h b/onnxruntime/core/providers/nuphar/compiler/initializer_info.h
new file mode 100644
index 0000000000000..b2fd47829828e
--- /dev/null
+++ b/onnxruntime/core/providers/nuphar/compiler/initializer_info.h
@@ -0,0 +1,34 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
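+// initializer_info.h describes how an initializer (weight) is marshalled into a
+// target-specific layout: WeightLayoutCodegenInfo tracks the marshalled tensor and
+// layout name, and InitializerInfo pairs the original ORT tensor with that layout info.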
+ +#pragma once +#include "core/framework/tensor.h" +#include + +// TODO: move to nuphar +namespace onnxruntime { +namespace nuphar { + +// TODO: move it to weight layout place +struct WeightLayoutCodegenInfo { + const Tensor* marshalled_initializer = nullptr; // TODO: change it to unique_ptr + std::string layout = ""; // layout name + tvm::Tensor marshalled_tensor; + tvm::Tensor unmarshalled_tensor; + bool is_marshalled; + + WeightLayoutCodegenInfo(const tvm::Tensor& tvm_tensor) + : marshalled_tensor(tvm_tensor), unmarshalled_tensor(tvm_tensor), is_marshalled(false) {} +}; + +struct InitializerInfo { + const Tensor* original_initializer = nullptr; // original ort tensor + std::unique_ptr layout_info = nullptr; + + InitializerInfo(const Tensor* tensor) : original_initializer(tensor) {} +}; + +using InitializerMap = std::map; + +} // namespace nuphar +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/nuphar/compiler/nuphar_codegen_ctx.cc b/onnxruntime/core/providers/nuphar/compiler/nuphar_codegen_ctx.cc new file mode 100644 index 0000000000000..1c9c4dae9c39b --- /dev/null +++ b/onnxruntime/core/providers/nuphar/compiler/nuphar_codegen_ctx.cc @@ -0,0 +1,247 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "nuphar_codegen_ctx.h" + +#include "core/codegen/common/common.h" +#include "core/codegen/common/utils.h" +#include "core/codegen/mti/mti_tvm_utils.h" // TODO: remove this after decoupling layout compile and run +#include "core/providers/nuphar/common/analysis/subgraph_codegen_stats.h" +#include "core/codegen/passes/utils/ort_tvm_utils.h" // TODO: remove this after decoupling layout compile and run +#include // TODO: remove this after decoupling layout compile and run + +#include "core/providers/nuphar/common/nuphar_tvm_utils.h" + +namespace onnxruntime { +namespace nuphar { + +NupharCodeGenCtx::NupharCodeGenCtx( + const Node& node, + const std::map& initializers, + std::unordered_map>& global_generated_initializers, + const NupharCodeGenHandle* handle) + : CodeGenContext(handle), + nuphar_handle_(handle), + initializers_(initializers), + global_generated_initializers_(global_generated_initializers) { + // construct graph_stats + graph_stats_ = std::make_unique(nuphar_handle_->shape_inference); +} + +NupharCodeGenCtx::NupharCodeGenCtx( + const nuphar::NupharSubgraphUnit& subgraph, + std::unordered_map>& global_generated_initializers, + const NupharCodeGenHandle* handle) + : CodeGenContext(handle), + nuphar_handle_(handle), + initializers_(subgraph.initializers), + global_generated_initializers_(global_generated_initializers) { + graph_stats_ = std::make_unique(nuphar_handle_->shape_inference); + Promote(graph_stats_)->Evaluate(subgraph); +} + +// This is a temp function before we decouple weight layout compilation and run +// This will be moved. +// TODO: remove this. 
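+// LowerLayoutFunc creates the TVM op and schedule that marshal an initializer into
+// the target weight layout, then returns the corresponding "<layout name>_marshall"
+// PackedFunc, either loaded from the on-disk cache or freshly lowered, built, and cached.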
+static tvm::runtime::PackedFunc LowerLayoutFunc(const tvm_codegen::WeightLayout* layout) { + tvm::Array inputs; + tvm::Array outputs; + + layout->CreateLayoutMarshallingTVMOp(inputs, outputs); + + auto config = tvm::build_config(); + config->disable_select_rewriting = true; + auto S = tvm::create_schedule({outputs[0]->op}); + S[outputs[0]->op].compute_root(); + + std::string func_name = layout->Name() + "_marshall"; + + tvm::runtime::PackedFunc cached_func = nuphar::LoadTVMPackedFuncFromCache(func_name); + + if (cached_func == nullptr) { + auto lowered = tvm::lower(S, {inputs[0], outputs[0]}, func_name, {}, config); + auto module = tvm::build(lowered, tvm::target::llvm(), tvm::Target(), config); + tvm_codegen::DumpTVMModuleToFile(func_name, module); + nuphar::SaveTVMModuleToCache(func_name, module); + cached_func = module.GetFunction(func_name); + } + return cached_func; +} + +// This is a temp function before we decouple weight layout compilation and run. +// This will be moved. +// TODO: remove this. +static const Tensor* Marshalling( + const std::string& initializer_name, + std::unordered_map>& global_generated_initializers, + const Tensor* original_initializer, + const tvm_codegen::WeightLayout* layout_ptr, + WeightLayoutCtx& ctx_layout, + AllocatorPtr allocator) { + tvm::runtime::PackedFunc packed_func; + + const std::string& layout_key = layout_ptr->Name(); + if (ctx_layout.weight_layout_to_packed_func.count(layout_key) == 0) { + packed_func = LowerLayoutFunc(layout_ptr); + ctx_layout.weight_layout_to_packed_func.insert(std::make_pair(layout_key, packed_func)); + } else { + packed_func = ctx_layout.weight_layout_to_packed_func[layout_key]; + } + + std::vector marshalled_shape = layout_ptr->ToActualShape(original_initializer); + auto marshalled_size = TotalSize(marshalled_shape); + auto byte_size = original_initializer->DataType()->Size(); + + std::unique_ptr out_ptr; + void* p_data = allocator->Alloc(marshalled_size * byte_size); + out_ptr = std::make_unique( + original_initializer->DataType(), + TensorShape(marshalled_shape), + p_data, + allocator->Info()); + + global_generated_initializers.emplace(initializer_name, std::move(out_ptr)); + + int num_args = 2; + DLContext tvm_ctx{kDLCPU, 0}; + std::vector lvalues(num_args); + std::vector tvm_tensors(num_args); + + // input + const auto& tensor_shape = original_initializer->Shape(); + auto input_shape = tensor_shape.GetDims(); + if (input_shape.empty()) + input_shape.push_back(1); + const void* input_data = original_initializer->DataRaw(); + DLDataType tvm_dtype = tvm_codegen::ToTvmDLDataType(original_initializer->DataType()); + + tvm_tensors[0] = {const_cast(input_data), tvm_ctx, + gsl::narrow_cast(input_shape.size()), tvm_dtype, + input_shape.data(), nullptr, 0}; + lvalues[0].v_handle = &(tvm_tensors[0]); + + // output + tvm_tensors[1] = {p_data, tvm_ctx, + gsl::narrow_cast(marshalled_shape.size()), tvm_dtype, + marshalled_shape.data(), nullptr, 0}; + lvalues[1].v_handle = &(tvm_tensors[1]); + + auto types_code = std::vector(num_args, kNDArrayContainer); + tvm::TVMArgs tvm_args(lvalues.data(), types_code.data(), num_args); + tvm::TVMRetValue rvalue; + packed_func.CallPacked(tvm_args, &rvalue); + return global_generated_initializers.at(initializer_name).get(); +} + +// on the fly WeightLayout transformer +tvm::Tensor NupharCodeGenCtx::ApplyWeightLayout( + const std::string& layout_key, + const std::string& initializer_name, + const tvm::Tensor& X, + bool returnMarshalled) { + tvm::Tensor marshalled; + 
ORT_ENFORCE(IsInitializer(initializer_name)); + auto layout_info = GetWeightLayoutInfo(initializer_name); + ORT_ENFORCE(nullptr != layout_info); + + const Tensor* original_initializer = GetOrtInitializerTensor(initializer_name); + + auto layout_ptr = nuphar_handle_->layout_registry->Get(layout_key); + ORT_ENFORCE(nullptr != layout_ptr); + + // check whether the weight is applied layout marshalling + if (nullptr == layout_info->marshalled_initializer) { + ORT_ENFORCE(!layout_info->is_marshalled); // initializer should not have been marshalled before + + // TODO: change to delayed call + layout_info->layout = layout_ptr->Name(); + + // TODO: change to delayed call + layout_info->marshalled_initializer = + Marshalling(initializer_name, + global_generated_initializers_, + original_initializer, + layout_ptr, + weight_layout_ctx_, + nuphar_handle_->allocator); + + layout_info->marshalled_tensor = tvm::placeholder(layout_ptr->ToActualShape(X), X->dtype, initializer_name + "_marshalled"); + layout_info->unmarshalled_tensor = tvm::compute( + X->shape, + [&](const tvm::Array& nominal_coord) { + tvm::Array cc; + for (auto v : nominal_coord) + cc.push_back(v); + + auto coord_trans_func = layout_ptr->ToActual(X); + return layout_info->marshalled_tensor(coord_trans_func(cc)); + }, + initializer_name + "_unmarshalled"); + + layout_info->is_marshalled = true; + + } else { + ORT_ENFORCE(layout_ptr->Name() == layout_info->layout); + } + + if (returnMarshalled) { + return layout_info->marshalled_tensor; + } + return layout_info->unmarshalled_tensor; +} + +const NupharSubgraphUnitStats* NupharCodeGenCtx::GetGraphStats() const { + return graph_stats_.get(); +} + +bool NupharCodeGenCtx::IsInitializer(const std::string& name) const { + return initializers_.count(name) > 0; +} + +const Tensor* NupharCodeGenCtx::GetOrtInitializerTensor(const std::string& name) const { + if (IsInitializer(name)) + return initializers_.at(name); + return nullptr; +} + +WeightLayoutCodegenInfo* NupharCodeGenCtx::GetWeightLayoutInfo(const std::string& name) { + if (initializer_layouts_.count(name) > 0) + return initializer_layouts_.at(name).get(); + return nullptr; +} + +const WeightLayoutCodegenInfo* NupharCodeGenCtx::GetWeightLayoutInfo(const std::string& name) const { + if (initializer_layouts_.count(name) > 0) + return initializer_layouts_.at(name).get(); + return nullptr; +} + +void NupharCodeGenCtx::CreateWeightLayoutInfo(const std::string& name, const tvm::Tensor& tensor) { + ORT_ENFORCE(initializer_layouts_.count(name) == 0); + initializer_layouts_.emplace(name, std::move(std::make_unique(tensor))); +} + +const std::map>& NupharCodeGenCtx::GetWeightLayoutMap() const { + return initializer_layouts_; +} + +void NupharCodeGenCtx::RecordTensorToNode(const tvm::Tensor& t, const Node* node) { + // Insert tvm::Tensor and Node to the lookup table + // But bypass it when node is a output alias + if (!Promote(graph_stats_)->IsOutputAlias(node)) + tvm_tensor_to_node_lookup_.insert(std::make_pair(t->op.get(), node)); +} + +const Node* NupharCodeGenCtx::FindNode(const tvm::Tensor& t) const { + auto p = tvm_tensor_to_node_lookup_.find(t->op.get()); + if (p != tvm_tensor_to_node_lookup_.end()) + return p->second; + return nullptr; +} + +const NupharCodeGenHandle* NupharCodeGenCtx::GetCodeGenHandle() const { + return nuphar_handle_; +} + +} // namespace nuphar +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/nuphar/compiler/nuphar_codegen_ctx.h b/onnxruntime/core/providers/nuphar/compiler/nuphar_codegen_ctx.h new file mode 
100644 index 0000000000000..69ffa04adb3cd --- /dev/null +++ b/onnxruntime/core/providers/nuphar/compiler/nuphar_codegen_ctx.h @@ -0,0 +1,147 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/codegen/common/common.h" +#include "core/codegen/passes/utils/codegen_context.h" +#include "core/common/common.h" +#include "core/graph/graph.h" +#include "core/providers/nuphar/common/analysis/graph_stats.h" +#include "core/providers/nuphar/common/nuphar_subgraph.h" +#include "core/providers/nuphar/compiler/initializer_info.h" +#include "core/providers/nuphar/compiler/nuphar_handle.h" + +#include + +namespace onnxruntime { +namespace nuphar { + +// Nuphar Tensor Context +struct TVMTensorCtx { + std::map inputs; + std::map> ops; + std::map> input_from; + + bool Lookup(const NodeArg* def, tvm::Tensor& tensor) { + const std::string& def_name = def->Name(); + auto iter = inputs.find(def_name); + if (iter != inputs.end()) { + tensor = iter->second; + return true; + } + + auto iter_out_index = input_from.find(def_name); + + if (iter_out_index == input_from.end()) { + return false; + } + + const Node* from_node = iter_out_index->second.first; + size_t index = iter_out_index->second.second; + auto iter_op = ops.find(from_node); + ORT_ENFORCE(iter_op != ops.end()); + tensor = iter_op->second[index]; + return true; + } + + const tvm::Tensor + Lookup(const NodeArg* def) const { + const std::string& def_name = def->Name(); + auto iter = inputs.find(def_name); + if (iter != inputs.end()) { + return iter->second; + } + + auto iter_out_index = input_from.find(def_name); + + ORT_ENFORCE(iter_out_index != input_from.end()); + + const Node* from_node = iter_out_index->second.first; + size_t index = iter_out_index->second.second; + auto iter_op = ops.find(from_node); + ORT_ENFORCE(iter_op != ops.end()); + return iter_op->second[index]; + } +}; + +struct WeightLayoutCtx { + //std::map initializer_to_weight_layout; // unused yet. 
This is for decoupling weight layout compile and run + std::unordered_map weight_layout_to_packed_func; +}; + +// NupharCodeGenCtx is Nuphar-specific CodeGenContext +class NupharCodeGenCtx : public tvm_codegen::CodeGenContext { + public: + NupharCodeGenCtx(const Node& node, + const std::map& initializers, + std::unordered_map>& global_generated_initializers, + const NupharCodeGenHandle* handle); + + NupharCodeGenCtx(const nuphar::NupharSubgraphUnit& subgraph, + std::unordered_map>& global_generated_initializers, + const NupharCodeGenHandle* handle); + + virtual ~NupharCodeGenCtx() = default; + + const NupharSubgraphUnitStats* GetGraphStats() const; + + bool IsInitializer(const std::string& name) const; + const Tensor* GetOrtInitializerTensor(const std::string& name) const; + WeightLayoutCodegenInfo* GetWeightLayoutInfo(const std::string& name); + const WeightLayoutCodegenInfo* GetWeightLayoutInfo(const std::string& name) const; + void CreateWeightLayoutInfo(const std::string& name, const tvm::Tensor& tensor); + const std::map>& GetWeightLayoutMap() const; + + // On-the-fly apply an existing layout + tvm::Tensor ApplyWeightLayout( + const std::string& layout_key, + const std::string& initializer_name, + const tvm::Tensor& X, + bool returnMarshalled); + + void RecordTensorToNode(const tvm::Tensor& t, const Node* node); + const Node* FindNode(const tvm::Tensor& t) const; + + const NupharCodeGenHandle* GetCodeGenHandle() const; + + // TODO remove this after decoupling compiler and runtime of WeightLayout + template + IAllocatorUniquePtr AllocateT(size_t size) const { return IAllocator::MakeUniquePtr(nuphar_handle_->allocator, size); } + // TODO remove this after decoupling compiler and runtime of WeightLayout + IAllocatorUniquePtr Allocate(size_t size) const { return AllocateT(size); } + + // Keep for CodeGenContext + TVMTensorCtx& GetTVMTensorCtx() { + return tvm_tensor_ctx_; + } + + // Keep for CodeGenContext + const TVMTensorCtx& GetTVMTensorCtx() const { + return tvm_tensor_ctx_; + } + + private: + std::unique_ptr graph_stats_; + + const NupharCodeGenHandle* nuphar_handle_; + + const std::map& initializers_; + + // A table from tvm::Tensor (its unchanged source tvm::Node*) to ORT Node + std::unordered_map tvm_tensor_to_node_lookup_; + + // All TVM Tensor and correponidng shape context + TVMTensorCtx tvm_tensor_ctx_; + + // local copy + std::map> initializer_layouts_; + + std::unordered_map>& global_generated_initializers_; + + // all layouts + WeightLayoutCtx weight_layout_ctx_; +}; + +} // namespace nuphar +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/nuphar/compiler/nuphar_compiler.cc b/onnxruntime/core/providers/nuphar/compiler/nuphar_compiler.cc new file mode 100644 index 0000000000000..61716f1a1f40e --- /dev/null +++ b/onnxruntime/core/providers/nuphar/compiler/nuphar_compiler.cc @@ -0,0 +1,229 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
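+
+// nuphar_compiler.cc: NupharCompiler builds TVM IR for a NupharSubgraphUnit (Build)
+// or for a Scan node's subgraph (BuildSubgraph), then lowers and compiles it into a
+// cached tvm::runtime::PackedFunc while filling the corresponding NupharFuncInfo
+// (Lower / GetLoweredPackedFunc).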
+ +#include "core/providers/nuphar/compiler/nuphar_compiler.h" + +#include "core/codegen/common/profile.h" +#include "core/codegen/common/settings.h" +#include "core/codegen/mti/mti_tvm_utils.h" +#include "core/codegen/passes/utils/ort_tvm_utils.h" +#include "core/mlas/inc/mlas.h" +#include "core/providers/nuphar/common/analysis/subgraph_codegen_stats.h" +#include "core/providers/nuphar/common/nuphar_settings.h" +#include "core/providers/nuphar/common/nuphar_tvm_utils.h" +#include "core/providers/nuphar/compiler/nuphar_handle.h" +#include "core/providers/nuphar/compiler/nuphar_op_ir_builder.h" +#include "core/providers/nuphar/compiler/nuphar_schedule_builder.h" + +namespace onnxruntime { +namespace nuphar { + +static void HandleAllOutputs( + const std::vector& outputs, + tvm::Array& tvm_args, + tvm::Array& tvm_outputs, + const NupharCodeGenCtx& context) { + // find out all outputs + std::set visited_alias_def; + auto add_tvm_arg_and_output = [&](const onnxruntime::NodeArg* def) { + auto& tvm_tensor = context.GetTVMTensorCtx().Lookup(def); + tvm_args.push_back(tvm_tensor); + tvm_outputs.push_back(tvm_tensor); + }; + + for (const NodeArg* def : outputs) { + const NodeArg* input_def = Promote(context.GetGraphStats())->SourceDefOfOutputAlias(def); + if (input_def) { + auto key = GetKey(input_def); + if (visited_alias_def.count(key) == 0) { + visited_alias_def.insert(key); + add_tvm_arg_and_output(input_def); + } + } else { + auto key = GetKey(def); + if (visited_alias_def.count(key) == 0) { + visited_alias_def.insert(key); + add_tvm_arg_and_output(def); + } + } + } +} + +// Constructor for Node +// This is mainly for single node support +// For multiple subgraph support, please call the next constructor +NupharCompiler::NupharCompiler(const Node& node, + const std::map& initializer, + std::unordered_map>& generated_initializers, + const NupharCodeGenHandle* handle) + : num_initializers_in_graph_inputs_(0), + context_(node, initializer, generated_initializers, handle) {} + +NupharCompiler::NupharCompiler(const nuphar::NupharSubgraphUnit& subgraph, + std::unordered_map>& generated_initializers, + const NupharCodeGenHandle* handle) + : num_initializers_in_graph_inputs_(0), + context_(subgraph, generated_initializers, handle) {} + +Status NupharCompiler::Build(const nuphar::NupharSubgraphUnit& subgraph) { + if (subgraph.nodes.front()->OpType() == "Scan") { + return BuildSubgraph(*subgraph.nodes.front()); + } + + tvm_args_ = tvm::Array(); + tvm_outputs_ = tvm::Array(); + + ORT_RETURN_IF_ERROR(CreateTVMIR(subgraph, context_)); + + // fill in all non-initializer inputs + num_initializers_in_graph_inputs_ = 0; + for (auto& def : subgraph.inputs) { + if (context_.IsInitializer(def->Name())) { + ++num_initializers_in_graph_inputs_; + } else { + tvm_args_.push_back(context_.GetTVMTensorCtx().Lookup(def)); + } + } + + // fill in all initializers + for (const auto& item : context_.GetWeightLayoutMap()) { + const WeightLayoutCodegenInfo* layout_info = item.second.get(); + tvm_args_.push_back(layout_info->marshalled_tensor); + } + + // find out all outputs, and save the output shapes + HandleAllOutputs(subgraph.outputs, tvm_args_, tvm_outputs_, context_); + + return Status::OK(); +} + +// BuildSubgraph drive a graph traversal that calls CreateInput and CreateOutputs metioned above for a subgraph. +// And collect args among nodes. 
+// We need another API other than Build, because name mismatching +Status NupharCompiler::BuildSubgraph(const Node& node) { + tvm_args_ = tvm::Array(); + tvm_outputs_ = tvm::Array(); + + auto subgraph = GetSubgraph(node); + + ORT_RETURN_IF_ERROR(CreateTVMIR(GraphViewer(*subgraph), context_, /*use_placeholder_for_input*/ true)); + + num_initializers_in_graph_inputs_ = 0; + // fill in all non-initializer inputs + + for (const auto& input : subgraph->GetInputs()) { + if (context_.IsInitializer(input->Name())) { + ++num_initializers_in_graph_inputs_; + } else { + tvm_args_.push_back(context_.GetTVMTensorCtx().Lookup(input)); + } + } + + // fill in implicit inputs + for (const auto& input : node.ImplicitInputDefs()) { + if (context_.IsInitializer(input->Name())) { + ++num_initializers_in_graph_inputs_; + } else { + tvm_args_.push_back(context_.GetTVMTensorCtx().Lookup(input)); + } + } + + // fill in all initializers + for (const auto& item : context_.GetWeightLayoutMap()) { + const WeightLayoutCodegenInfo* layout_info = item.second.get(); + tvm_args_.push_back(layout_info->marshalled_tensor); + } + + // find out all outputs + HandleAllOutputs(subgraph->GetOutputs(), tvm_args_, tvm_outputs_, context_); + + return Status::OK(); +} + +tvm::runtime::PackedFunc NupharCompiler::GetLoweredPackedFunc( + const std::string& func_name, + tvm::Target tvm_target, + tvm::Target tvm_host_target, + const tvm::BuildConfig& config, + const std::string& subgraph_type, + const std::string& subgraph_name) { + // TODO: refactor the following logic for both JIT-caching and AOT support + // JIT-caching and AOT are mutual exclusive. + // Change it by not always saving a compiled func unless it is in JIT-Caching model. + // In AOT, there should be another member func explicitly loading + tvm::runtime::PackedFunc cached_func = nuphar::LoadTVMPackedFuncFromCache(func_name); + if (cached_func == nullptr) { + codegen::CodeGenSettings& settings = codegen::CodeGenSettings::Instance(); + + if (settings.HasOption(kNupharCacheForceNoJIT)) { + if (settings.OptionMatches(kNupharCacheForceNoJIT, "on")) { + ORT_THROW("Force not using JIT code!"); + } + } + + tvm::Schedule tvm_schedule = CreateSchedule(tvm_outputs_, context_); + std::unordered_map binds; + tvm::Array lowered = tvm::lower(tvm_schedule, tvm_args_, func_name, binds, config); + + if (settings.HasOption(codegen::CodeGenSettings::kCodeGenDumpLower)) { + if (settings.OptionMatches(codegen::CodeGenSettings::kCodeGenDumpLower, "verbose") || + settings.OptionMatches(codegen::CodeGenSettings::kCodeGenDumpLower, subgraph_type)) { + for (const auto& func : lowered) + LOGS_DEFAULT(CODEGEN_SETTINGS_LOG_LEVEL) << "[CODEGEN_DUMP_LOWER] Dumping lowered func: " << func << std::endl + << func->body; + } else if (settings.OptionMatches(codegen::CodeGenSettings::kCodeGenDumpLower, "concise")) { + LOGS_DEFAULT(CODEGEN_SETTINGS_LOG_LEVEL) << "[CODEGEN_DUMP_LOWER] Subgraph Type: " + << subgraph_type << ", name: " << subgraph_name + << " #lowered funcs: " << lowered.size() << std::endl; + } + } + + tvm::runtime::Module module = tvm::build(lowered, tvm_target, tvm_host_target, config); + tvm_codegen::DumpTVMModuleToFile(func_name, module); + nuphar::SaveTVMModuleToCache(func_name, module); + cached_func = module.GetFunction(func_name); + } + + return cached_func; +} + +static tvm::BuildConfig CreateConfig(const Node& node, + bool allow_unaligned_buffers) { + tvm::BuildConfig config = tvm::build_config(); + config->disable_select_rewriting = true; + + if (allow_unaligned_buffers) { + 
config->data_alignment = 1; // aligned to 1 + } else { + config->data_alignment = gsl::narrow(MlasGetPreferredBufferAlignment()); + } + + config->restricted_func = true; + return config; +} + +// Lower compiles the tvm::Tensor to a function +Status NupharCompiler::Lower(const nuphar::NupharSubgraphUnit& subgraph, + tvm::Target tvm_target, + tvm::Target tvm_host_target, + NupharFuncInfo* func_info, + nuphar::OrtSubgraphAllocationInfo* partition_info) { + const auto& target_codegen = *context_.GetCodeGenHandle()->codegen_target; + std::string func_name = nuphar::GetPackedFuncName(subgraph, target_codegen); + tvm::BuildConfig config = CreateConfig(*subgraph.nodes.front(), + context_.GetCodeGenHandle()->allow_unaligned_buffers); + + // using "subgraph" for type and name for now + // TODO: change name + tvm::runtime::PackedFunc cached_func = + GetLoweredPackedFunc( + func_name, tvm_target, tvm_host_target, + config, "subgraph", "subgraph"); + + FillNupharFuncInfo(func_info, partition_info, subgraph, context_, tvm_target, cached_func, func_name); + + return Status::OK(); +} + +} // namespace nuphar +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/nuphar/compiler/nuphar_compiler.h b/onnxruntime/core/providers/nuphar/compiler/nuphar_compiler.h new file mode 100644 index 0000000000000..1b4b4d5376f99 --- /dev/null +++ b/onnxruntime/core/providers/nuphar/compiler/nuphar_compiler.h @@ -0,0 +1,65 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/codegen/common/common.h" +#include "core/providers/nuphar/common/nuphar_subgraph.h" +#include "core/providers/nuphar/compiler/func_info.h" +#include "core/providers/nuphar/compiler/initializer_info.h" +#include "core/providers/nuphar/compiler/nuphar_codegen_ctx.h" +#include "core/providers/nuphar/compiler/nuphar_handle.h" +#include "core/providers/nuphar/compiler/traverse_shape_infer.h" +#include "core/framework/op_kernel.h" +#include "core/graph/graph.h" +#include "gsl/gsl_util" + +#include +#include + +namespace onnxruntime { +namespace nuphar { + +class NupharCompiler { + public: + NupharCompiler(const Node& node, + const std::map& initializers, + std::unordered_map>& generated_initializers, + const NupharCodeGenHandle* handle); + + NupharCompiler(const nuphar::NupharSubgraphUnit& subgraph, + std::unordered_map>& generated_initializers, + const NupharCodeGenHandle* handle); + + // Build builds tvm IR and apply passes + Status Build(const nuphar::NupharSubgraphUnit& subgraph); + + // Lower lowers the built tvm IR to llvm ir and compiles it + Status Lower(const nuphar::NupharSubgraphUnit& subgraph, + tvm::Target tvm_target, + tvm::Target tvm_host_target, + NupharFuncInfo* ctx_func, + nuphar::OrtSubgraphAllocationInfo* partition_info); + + tvm::runtime::PackedFunc GetLoweredPackedFunc( + const std::string& func_name, + tvm::Target tvm_target, + tvm::Target tvm_host_target, + const tvm::BuildConfig& config, + const std::string& subgraph_type, + const std::string& subgraph_name); + + private: + size_t num_initializers_in_graph_inputs_; + + // BuildSubgraph builds tvm IR and apply passes for a subgraph + Status BuildSubgraph(const Node& node); + + NupharCodeGenCtx context_; + + tvm::Array tvm_args_; + tvm::Array tvm_outputs_; +}; + +} // namespace nuphar +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/nuphar/compiler/nuphar_handle.h b/onnxruntime/core/providers/nuphar/compiler/nuphar_handle.h new file mode 100644 index 
0000000000000..84be4555ba0f4 --- /dev/null +++ b/onnxruntime/core/providers/nuphar/compiler/nuphar_handle.h @@ -0,0 +1,40 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/codegen/common/common.h" +#include "core/codegen/common/handle.h" +#include "core/codegen/common/target_info.h" +#include "core/codegen/passes/weight_layout/weight_layout.h" +#include "core/framework/allocator.h" // TODO: get rid of this +#include "core/providers/nuphar/compiler/traverse_shape_infer.h" // TODO: get rid of this + +namespace onnxruntime { + +// forwarding +namespace tvm_codegen { +class TVMIRBuilder; +class TVMScheduleBuilder; +} // namespace tvm_codegen + +namespace nuphar { + +// TVM is a wrapper containing CodeGen related setting +// TODO: make this the Base +// TODO: create one for nuphar +struct NupharCodeGenHandle : codegen::CodeGenHandle { + std::shared_ptr op_ir_builder; // keep + std::shared_ptr schedule_builder; // keep + // maybe add a layout + tvm_codegen::WeightLayoutRegistry* layout_registry; + bool enable_per_node_parallelized; // TODO: change to config + + bool allow_unaligned_buffers; // move to another place + + AllocatorPtr allocator; // remove + std::shared_ptr shape_inference; // remove +}; + +} // namespace nuphar +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/nuphar/compiler/nuphar_op_ir_builder.cc b/onnxruntime/core/providers/nuphar/compiler/nuphar_op_ir_builder.cc new file mode 100644 index 0000000000000..4578134d359ec --- /dev/null +++ b/onnxruntime/core/providers/nuphar/compiler/nuphar_op_ir_builder.cc @@ -0,0 +1,311 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/nuphar/compiler/nuphar_op_ir_builder.h" + +#include "core/codegen/common/op_macro.h" +#include "core/codegen/mti/mti_tvm_utils.h" +#include "core/codegen/passes/op_ir_creator/all_ops.h" +#include "core/codegen/passes/op_ir_creator/tvm_ir_builder.h" +#include "core/codegen/passes/utils/ort_tvm_utils.h" +#include "core/common/common.h" +#include "core/providers/nuphar/compiler/initializer_info.h" +#include "core/providers/nuphar/compiler/x86/op_ir_creator/all_ops.h" + +namespace onnxruntime { +namespace nuphar { + +// Declaration of GetOrCreateInitializer +// GetOrCreateInitializer create tvm::placeholder for a marshalled weight +// with correpsonding data layout transfomration for a weight, +// Note the weight is fed during build +static const tvm::Tensor& GetOrCreateInitializer(const std::string& name, + const Tensor* tensor, + bool is_sliced, + NupharCodeGenCtx& ctx_codegen); + +static const tvm::Tensor& GetOrCreateInitializer(const NodeArg* def, + const Tensor* tensor, + bool is_sliced, + NupharCodeGenCtx& ctx_codegen); + +// CreateInputPlaceholder create tvm input placeholder (tvm::Tensor) +// NOTE: here we assume axis 0 is sequence +// TODO: add support for sequence not axis 0 +static tvm::Tensor CreateInputPlaceholder(const tvm::Array& shape, + HalideIR::Type halide_type, + const std::string& name, + bool is_sliced) { + return tvm::placeholder(is_sliced && shape.size() > 1 ? 
tvm_codegen::SliceShapeFromDimension(shape, 1) : shape, halide_type, name); +} + +// CreateInput creats tvm::Tensor of corresponding ORT input +// Inputs are either initializer or regular input placeholder +static bool CreateInput( + const NodeArg* def, + tvm::Tensor& input, + bool initializer_only, + bool is_sliced, + NupharCodeGenCtx& ctx_codegen) { + const Tensor* initialized_tensor = ctx_codegen.GetOrtInitializerTensor(def->Name()); + if (nullptr == initialized_tensor && initializer_only) + return false; + + ORT_ENFORCE(def->Shape()); + if (nullptr != initialized_tensor) { + input = GetOrCreateInitializer(def, initialized_tensor, is_sliced, ctx_codegen); + } else { + // Handle inputs without initializer + std::string name = NormalizeNodeArgName(def); + MLDataType ONNXRUNTIME_data_type = DataTypeImpl::TypeFromProto(*def->TypeAsProto()); + DLDataType dtype = tvm_codegen::ToTvmDLDataType(ONNXRUNTIME_data_type); + HalideIR::Type halide_type((halideir_type_code_t)dtype.code, dtype.bits, dtype.lanes); + tvm::Array shape = ShapeToTvmArray(def, ctx_codegen); + + // Create InputPlaceholder + // Slice InputPlaceholder if it is asked for. + input = CreateInputPlaceholder(shape, halide_type, name, is_sliced); + } + return true; +} + +// GetOrCreateInitializer create tvm::placeholder for a marshalled weight +// with correpsonding data layout transfomration for a weight, +// Note the weight is fed during build +const tvm::Tensor& GetOrCreateInitializer(const std::string& name, + const Tensor* tensor, + bool is_sliced, + NupharCodeGenCtx& ctx_codegen) { + ORT_ENFORCE(ctx_codegen.IsInitializer(name)); + + auto layout_info = ctx_codegen.GetWeightLayoutInfo(name); + if (nullptr != layout_info) { + return layout_info->marshalled_tensor; + } + + auto ONNXRUNTIME_data_type = tensor->DataType(); + DLDataType dtype = tvm_codegen::ToTvmDLDataType(ONNXRUNTIME_data_type); + HalideIR::Type halide_type((halideir_type_code_t)dtype.code, dtype.bits, dtype.lanes); + std::string normalized_name = NormalizeCppName(name); + auto tvm_shape = tvm_codegen::ToTvmArray(tensor->Shape().GetDims()); + auto tvm_tensor = CreateInputPlaceholder(tvm_shape, halide_type, normalized_name, is_sliced); + // create the layout info + ctx_codegen.CreateWeightLayoutInfo(name, tvm_tensor); + return ctx_codegen.GetWeightLayoutInfo(name)->marshalled_tensor; +} + +const tvm::Tensor& GetOrCreateInitializer(const NodeArg* def, + const Tensor* tensor, + bool is_sliced, + NupharCodeGenCtx& ctx_codegen) { + return GetOrCreateInitializer(def->Name(), tensor, is_sliced, ctx_codegen); +} + +// CreateOutputs constructs tvm::Tensor with corresponding computation +static Status CreateOutputs(const Node* node, + const tvm::Array& inputs, + tvm::Array& outputs, + NupharCodeGenCtx& ctx_codegen) { + ORT_RETURN_IF_ERROR(ctx_codegen.GetCodeGenHandle() + ->op_ir_builder + ->Evaluate(inputs, *node, ctx_codegen, outputs)); + + // Collect constructed tvm::Node to onnxruntime::Node mapping + // Both states and outputs + for (const auto& t : outputs) { + ctx_codegen.RecordTensorToNode(t, node); + } + + return Status::OK(); +} + +// CreateTVMIR is the entry function for building TVM IR +// It will call TVMIRBuilder (in CreateOutputs) from CodeGenContext +Status CreateTVMIR( + const GraphViewer& graph, + NupharCodeGenCtx& ctx_codegen, + bool use_placeholder_for_input) { + TVMTensorCtx& ctx_tensor = ctx_codegen.GetTVMTensorCtx(); + + if (use_placeholder_for_input) { + // build graph inputs + const auto& graph_inputs = graph.GetInputs(); + for (size_t i = 0; i < 
graph_inputs.size(); ++i) { + tvm::Tensor value; + if (CreateInput(graph_inputs[i], value, + /*initializer_only*/ false, /*is_sliced*/ false, + ctx_codegen)) { + ctx_tensor.inputs.emplace(graph_inputs[i]->Name(), std::move(value)); + } + } + } + + for (const auto& node : graph.Nodes()) { + // initializers + node.ForEachWithIndex( + node.InputDefs(), + [&ctx_codegen, &ctx_tensor](const NodeArg& def, size_t) { + tvm::Tensor value; + if (CreateInput(&def, value, /*initializer_only*/ true, /*is_sliced*/ false, + ctx_codegen)) { + ctx_tensor.inputs.emplace(def.Name(), std::move(value)); + } + return Status::OK(); + }); + } + + // iterate through the graph and create op (outputs) + for (auto node_index : graph.GetNodesInTopologicalOrder()) { + const auto& node = *graph.GetNode(node_index); + tvm::Array inputs; + for (const NodeArg* def : node.InputDefs()) { + tvm::Tensor input; + if (def->Exists()) { + bool exist = ctx_tensor.Lookup(def, input); + if (!exist) { + tvm::Tensor value; + if (CreateInput(def, value, + /*initializer_only*/ false, /*is_sliced*/ false, + ctx_codegen)) { + ctx_tensor.inputs.emplace(def->Name(), std::move(value)); + } + input = ctx_tensor.Lookup(def); + } + } + inputs.push_back(input); + } + + auto subgraph = GetSubgraph(node); + if (nullptr != subgraph) { + // unboxing + GraphViewer subgraph_viewer(*subgraph); + ORT_RETURN_IF_ERROR(CreateTVMIR(subgraph_viewer, ctx_codegen, /*use_placeholder_for_input*/ false)); + } else { + tvm::Array op_outputs; + ORT_RETURN_IF_ERROR(CreateOutputs(&node, inputs, op_outputs, ctx_codegen)); + ctx_tensor.ops.emplace(&node, std::move(op_outputs)); + + // input_from_ + node.ForEachWithIndex( + node.OutputDefs(), + [&node, &ctx_tensor](const NodeArg& def, size_t index) { + ORT_ENFORCE(ctx_tensor.input_from.count(def.Name()) == 0); + ctx_tensor.input_from.emplace(def.Name(), std::make_pair(&node, index)); + return Status::OK(); + }); + } + } + + return Status::OK(); +} + +// CreateTVMIR is the entry function for building TVM IR +// It will call TVMIRBuilder (in CreateOutputs) from CodeGenContext +Status CreateTVMIR( + const Node& node, + NupharCodeGenCtx& ctx_codegen) { + // wrapper + TVMTensorCtx& ctx_tensor = ctx_codegen.GetTVMTensorCtx(); + bool has_loop = HasLoop(node); + + // create real Inputs + node.ForEachWithIndex( + node.InputDefs(), + [&has_loop, &ctx_codegen, &ctx_tensor](const NodeArg& def, size_t) { + tvm::Tensor value; + if (CreateInput(&def, value, /*initializer_only*/ false, /*is_sliced*/ has_loop, + ctx_codegen)) { + ctx_tensor.inputs.emplace(def.Name(), std::move(value)); + } + return Status::OK(); + }); + + // input_from_ + node.ForEachWithIndex( + node.OutputDefs(), + [&node, &ctx_tensor](const NodeArg& def, size_t index) { + ctx_tensor.input_from.emplace(def.Name(), std::make_pair(&node, index)); + return Status::OK(); + }); + + tvm::Array inputs; + for (const NodeArg* def : node.InputDefs()) { + inputs.push_back(def->Exists() ? 
ctx_tensor.Lookup(def) : tvm::Tensor()); + } + + // create ops (outputs) + tvm::Array op_outputs; + ORT_RETURN_IF_ERROR(CreateOutputs(&node, inputs, op_outputs, ctx_codegen)); + ctx_tensor.ops.emplace(&node, std::move(op_outputs)); + + return Status::OK(); +} + +// CreateTVMIR is the entry function for building TVM IR +// It will call TVMIRBuilder (in CreateOutputs) from CodeGenContext +Status CreateTVMIR( + const nuphar::NupharSubgraphUnit& subgraph, + NupharCodeGenCtx& ctx_codegen) { + //////////////////////////////////////// + // handle a special case for a single node + //////////////////////////////////////// + if (subgraph.IsSingleNode()) { + const Node* node = subgraph.nodes.front(); + + const Graph* onnx_graph = GetSubgraph(*node); + + if (nullptr != onnx_graph) { + return CreateTVMIR(GraphViewer(*onnx_graph), ctx_codegen, true); + } + return CreateTVMIR(*node, ctx_codegen); + } + + ////////////////////////////// + // handle a generic subgraph below + ////////////////////////////// + TVMTensorCtx& ctx_tensor = ctx_codegen.GetTVMTensorCtx(); + + // build subgraph inputs + for (const NodeArg* def : subgraph.inputs) { + tvm::Tensor value; + + if (CreateInput(def, value, /*initializer_only*/ false, /*is_sliced*/ false, + ctx_codegen)) { + ctx_tensor.inputs.emplace(def->Name(), std::move(value)); + } + } + + // build subgraph initializers + for (auto& p : subgraph.initializers) { + tvm::Tensor value = GetOrCreateInitializer(p.first, p.second, false, ctx_codegen); + ctx_tensor.inputs.emplace(p.first, std::move(value)); + } + + // iterate through the subgraph nodes and create op (outputs) + for (auto& node : subgraph.nodes) { + tvm::Array inputs; + + // collects local inputs + for (const NodeArg* def : node->InputDefs()) { + inputs.push_back(def->Exists() ? ctx_tensor.Lookup(def) : tvm::Tensor()); + } + + tvm::Array op_outputs; + ORT_RETURN_IF_ERROR(CreateOutputs(node, inputs, op_outputs, ctx_codegen)); + ctx_tensor.ops.emplace(node, std::move(op_outputs)); + + // input_from_ + node->ForEachWithIndex( + node->OutputDefs(), + [&node, &ctx_tensor](const NodeArg& def, size_t index) { + ORT_ENFORCE(ctx_tensor.input_from.count(def.Name()) == 0); + ctx_tensor.input_from.emplace(def.Name(), std::make_pair(node, index)); + return Status::OK(); + }); + } + + return Status::OK(); +} + +} // namespace nuphar +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/nuphar/compiler/nuphar_op_ir_builder.h b/onnxruntime/core/providers/nuphar/compiler/nuphar_op_ir_builder.h new file mode 100644 index 0000000000000..532e917b5d8cc --- /dev/null +++ b/onnxruntime/core/providers/nuphar/compiler/nuphar_op_ir_builder.h @@ -0,0 +1,34 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
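+
+// Note: the GraphViewer overload below is used when compiling a node's ONNX
+// subgraph (e.g. Scan, see NupharCompiler::BuildSubgraph), the Node overload
+// handles a single node, and the NupharSubgraphUnit overload is the entry
+// point for fused subgraphs (NupharCompiler::Build).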
+ +#pragma once + +#include "core/common/common.h" +#include "core/providers/nuphar/compiler/nuphar_codegen_ctx.h" + +#include "core/providers/nuphar/common/nuphar_subgraph.h" + +namespace onnxruntime { +namespace nuphar { + +// CreateTVMIR function traverses a GraphViewer +// and builds tvm ir (and store them in CodeGenContext) +// based on corresponding ORT ir +Status CreateTVMIR(const GraphViewer& graph, + NupharCodeGenCtx& ctx_codegen, + bool use_placeholder_for_input); + +// CreateTVMIR function traverses a single node +// and builds tvm ir (and store them in CodeGenContext) +// based on corresponding ORT ir +Status CreateTVMIR(const Node& node, + NupharCodeGenCtx& ctx_codegen); + +// CreateTVMIR function traverses a NupharSubgraphUnit +// and builds tvm ir (and store them in CodeGenContext) +// based on corresponding ORT ir +Status CreateTVMIR(const nuphar::NupharSubgraphUnit& subgraph, + NupharCodeGenCtx& ctx_codegen); + +} // namespace nuphar +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/nuphar/compiler/nuphar_schedule_builder.cc b/onnxruntime/core/providers/nuphar/compiler/nuphar_schedule_builder.cc new file mode 100644 index 0000000000000..2755f0c01aed1 --- /dev/null +++ b/onnxruntime/core/providers/nuphar/compiler/nuphar_schedule_builder.cc @@ -0,0 +1,77 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/nuphar/compiler/nuphar_schedule_builder.h" + +#include "core/codegen/common/settings.h" +#include "core/codegen/passes/scheduler/schedule_utils.h" +#include "core/codegen/passes/scheduler/tvm_schedule_builder.h" + +#include "core/providers/nuphar/common/analysis/subgraph_codegen_stats.h" + +// TODO change name space +namespace onnxruntime { +namespace nuphar { + +// Traverse iterates a tvm::Tensor and itself dependencies +// and builds schedule (in ScheduleContext) +// based on corresponding ORT ir and TVM ir +static void Traverse(const tvm::Tensor& tensor, + const Node* node, + NupharCodeGenCtx& ctx_codegen, + tvm_codegen::ScheduleContext& ctx_schedule) { + // no need to traverse on nodes already marked as closured + if (ctx_schedule.scheduled_tensors.count(tensor->op.get()) > 0) { + if (ctx_schedule.scheduled_tensors[tensor->op.get()] == tvm_codegen::ScheduleType::ScheduleClosure) { + return; + } + } + + ctx_codegen.GetCodeGenHandle()->schedule_builder->Evaluate(tensor, node, ctx_codegen, ctx_schedule); + + // for real ouput + bool is_real_output = nullptr != node && + Promote(ctx_codegen.GetGraphStats())->IsOutputNode(node); + + if (is_real_output) { + // TODO change it to the value from Target + int64_t natural_vector_size = 16; + + TryVectorization(tensor, natural_vector_size, ctx_schedule); // to x86 + InsertRootScheduleAndClosure(tensor, ctx_schedule); + } + + // Traverse tensor's children + for (auto& t : tensor->op->InputTensors()) { + // check whether it is a tensor having inputs + if (t->op->InputTensors().size() > 0) { + auto current_node = ctx_codegen.FindNode(t); + Traverse(t, current_node, ctx_codegen, ctx_schedule); + } + } +} + +tvm::Schedule CreateSchedule(const tvm::Array& outs, + NupharCodeGenCtx& ctx_codegen) { + // Create scheudule object + tvm::Array out_ops; + for (auto& t : outs) { + out_ops.push_back(t->op); + } + + if (codegen::CodeGenSettings::Instance().HasOption(codegen::CodeGenSettings::kCodeGenDumpSchedule)) + ctx_codegen.GetCodeGenHandle()->schedule_builder->DumpAllSchedulers(); + + tvm_codegen::ScheduleContext ctx_schedule(out_ops); + + // 
Schedule all outputs + for (const auto& t : outs) { + const Node* node = ctx_codegen.FindNode(t); + Traverse(t, node, ctx_codegen, ctx_schedule); + } + + return ctx_schedule.schedule; +} + +} // namespace nuphar +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/nuphar/compiler/nuphar_schedule_builder.h b/onnxruntime/core/providers/nuphar/compiler/nuphar_schedule_builder.h new file mode 100644 index 0000000000000..de4631b154afe --- /dev/null +++ b/onnxruntime/core/providers/nuphar/compiler/nuphar_schedule_builder.h @@ -0,0 +1,20 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include +#include "core/common/common.h" +#include "core/providers/nuphar/compiler/nuphar_codegen_ctx.h" + +// TODO change name space +namespace onnxruntime { +namespace nuphar { + +// CreateSchedule iterates a tvm::Array of output tensors +// and builds the whole schedule (in ScheduleContext) +tvm::Schedule CreateSchedule(const tvm::Array& outs, + NupharCodeGenCtx& ctx_codegen); + +} // namespace nuphar +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/nuphar/compiler/traverse_shape_infer.cc b/onnxruntime/core/providers/nuphar/compiler/traverse_shape_infer.cc new file mode 100644 index 0000000000000..4a20343113a07 --- /dev/null +++ b/onnxruntime/core/providers/nuphar/compiler/traverse_shape_infer.cc @@ -0,0 +1,128 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/nuphar/compiler/traverse_shape_infer.h" + +#include "core/codegen/common/common.h" +#include "core/common/common.h" +#include "core/framework/tensorprotoutils.h" + +// TODO retire this file + +namespace onnxruntime { +namespace nuphar { + +// local shape inference function for input +static bool CreateInput(const NodeArg* def, + const GraphViewer& graph, + ShapeExpr& input, + bool initializer_only) { + if (initializer_only && graph.GetAllInitializedTensors().count(def->Name()) == 0) + return false; + + auto def_shape = def->Shape(); + if (!def_shape) + return false; + + int rank = def_shape->dim_size(); + input = ShapeExpr(rank); + for (int i = 0; i < rank; ++i) { + const auto& dim = def_shape->dim()[i]; + if (dim.has_dim_value()) + input[i] = DimExpr(dim.dim_value()); + else if (dim.has_dim_param()) + input[i] = DimExpr(dim.dim_param()); + else { + input[i] = DimExpr(NormalizeNodeArgName(def) + "_dim" + std::to_string(i)); + } + } + return true; +} + +// local shape inference function for output
static Status CreateOutputs( + const Node* node, + const std::vector& inputs, + std::vector& outputs) { + outputs.resize(node->OutputDefs().size()); + node->ForEachWithIndex( + node->OutputDefs(), + [&](const NodeArg& def, size_t index) { + auto shape_proto = def.Shape(); + if (shape_proto) { + TensorShape shape{utils::GetTensorShapeFromTensorShapeProto(*shape_proto)}; + ShapeExpr output_shape(shape.NumDimensions()); + for (int d = 0; d < gsl::narrow(shape.NumDimensions()); ++d) { + if (shape[d] > 0) { + output_shape[d] = DimExpr(shape[d]); + } else { + ORT_RETURN_IF_NOT(shape_proto->dim_size() > d && shape_proto->dim(d).has_dim_param()); + output_shape[d] = DimExpr(shape_proto->dim(d).dim_param()); + } + } + outputs[index] = output_shape; + } + return Status::OK(); + }); + return Status::OK(); +} + +// The main function for shape inference +Status ShapeInference( + const GraphViewer& graph, + ShapeExprContext& context) { + // build graph inputs + const auto& graph_inputs = graph.GetInputs(); +
for (size_t i = 0; i < graph_inputs.size(); ++i) { + ShapeExpr value; + if (CreateInput(graph_inputs[i], graph, value, /*initializer_only*/ false)) { + context.inputs.emplace(graph_inputs[i]->Name(), std::move(value)); + } + } + + // perform shape inference using the topological order from ORT + for (const NodeIndex& node_index : graph.GetNodesInTopologicalOrder()) { + const Node& node = *graph.GetNode(node_index); + // initializers + node.ForEachWithIndex( + node.InputDefs(), + [&graph, &context](const NodeArg& def, size_t) { + ShapeExpr value; + if (CreateInput(&def, graph, value, /*initializer_only*/ true)) { + context.inputs.emplace(def.Name(), std::move(value)); + } + return Status::OK(); + }); + + // handle subgraph + const Graph* subgraph = GetSubgraph(node); + if (nullptr != subgraph) { + GraphViewer subgraph_viewer(*subgraph); + ShapeInference(subgraph_viewer, context); + } + + // collect inputs before creating outputs + std::vector inputs; + for (const NodeArg* def : node.InputDefs()) { + inputs.push_back(def->Exists() ? context.Lookup(def) : nullptr); + } + + // create outputs + std::vector op_outputs; + ORT_RETURN_IF_ERROR(CreateOutputs(&node, inputs, op_outputs)); + context.ops.emplace(&node, std::move(op_outputs)); + + // recall input_from_ + node.ForEachWithIndex( + node.OutputDefs(), + [&node, &context](const NodeArg& def, size_t index) { + context.input_from.emplace(def.Name(), std::make_pair(&node, index)); + return Status::OK(); + }); + } + + return Status::OK(); +} + +} // namespace nuphar +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/nuphar/compiler/traverse_shape_infer.h b/onnxruntime/core/providers/nuphar/compiler/traverse_shape_infer.h new file mode 100644 index 0000000000000..deaa5777a3c66 --- /dev/null +++ b/onnxruntime/core/providers/nuphar/compiler/traverse_shape_infer.h @@ -0,0 +1,49 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
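CreateInput above turns a NodeArg's shape proto into a ShapeExpr: a known dim_value becomes a constant dimension, a dim_param keeps its symbolic name, and an unnamed dynamic dimension gets a generated name. A small sketch of that mapping under simplified assumptions; SimpleDim and MakeShape are hypothetical stand-ins for DimExpr/ShapeExpr, not the real classes.

#include <cassert>
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-in for DimExpr: either a constant or a symbol.
struct SimpleDim {
  long long value = -1;   // >= 0 when the dimension is known
  std::string symbol;     // non-empty when the dimension is symbolic
};

// Mirrors the CreateInput mapping: known dims become constants, named dynamic
// dims keep their dim_param, unnamed ones get a generated "<arg>_dim<i>" name.
std::vector<SimpleDim> MakeShape(const std::vector<std::pair<long long, std::string>>& proto_dims,
                                 const std::string& arg_name) {
  std::vector<SimpleDim> shape;
  for (size_t i = 0; i < proto_dims.size(); ++i) {
    const auto& d = proto_dims[i];
    if (d.first >= 0)
      shape.push_back({d.first, ""});
    else if (!d.second.empty())
      shape.push_back({-1, d.second});
    else
      shape.push_back({-1, arg_name + "_dim" + std::to_string(i)});
  }
  return shape;
}

int main() {
  // e.g. an input "X" with shape ["batch", 128, <unknown>]
  auto s = MakeShape({{-1, "batch"}, {128, ""}, {-1, ""}}, "X");
  assert(s[0].symbol == "batch" && s[1].value == 128 && s[2].symbol == "X_dim2");
  return 0;
}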
+ +#pragma once + +#include "core/providers/nuphar/common/analysis/shape_expr.h" +#include "core/common/common.h" +#include "core/framework/tensor.h" +#include "core/graph/graph_viewer.h" + +namespace onnxruntime { +namespace nuphar { + +// A collection of ShapeExpr +struct ShapeExprContext { + std::map inputs; + std::map> ops; + std::map> input_from; + + const ShapeExpr* Lookup(const NodeArg* def) const { + const std::string& def_name = def->Name(); + auto iter = inputs.find(def_name); + if (iter != inputs.end()) + return &(iter->second); + + auto iter_out_index = input_from.find(def_name); + + // OK if shape inference is incomplete + // This is for some per-node unit test where NodeArg does not even have shape ranks + // We ignore the shape inference in ToCapacity computation in per-node unit tests + if (iter_out_index == input_from.end()) + return nullptr; + + const Node* from_node = iter_out_index->second.first; + size_t index = iter_out_index->second.second; + auto iter_op = ops.find(from_node); + ORT_ENFORCE(iter_op != ops.end()); + return &(iter_op->second[index]); + } +}; + +// Traverse function traverses a GraphViewer, +// performs shape infernce, +// and builds ShapeExpr in ShapeExprContext +Status ShapeInference(const GraphViewer& graph, + ShapeExprContext& context); + +} // namespace nuphar +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/nuphar/compiler/x86/op_ir_creator/all_ops.h b/onnxruntime/core/providers/nuphar/compiler/x86/op_ir_creator/all_ops.h new file mode 100644 index 0000000000000..2d2c55b17c169 --- /dev/null +++ b/onnxruntime/core/providers/nuphar/compiler/x86/op_ir_creator/all_ops.h @@ -0,0 +1,64 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once +#include "core/codegen/passes/utils/codegen_context.h" +#include "core/codegen/passes/op_ir_creator/tvm_op_creator.h" + +namespace onnxruntime { +namespace nuphar { + +// Declare a TVM IR builder based on the ORT OP type +// with postfix NupharTVMX86 +#define DECLARE_NUPHAR_TVM_X86_OP_IR_CREATOR_CLASS(OP) \ + DECLARE_OP_IR_CREATOR_CLASS_EX(OP, NupharTVM, X86) + +// Return a TVM IR builder class name such as OP type +// with postfix NupharTVMX86 +#define NUPHAR_TVM_X86_OP_IR_CREATOR_CLASS(OP) \ + OP_IR_CREATOR_CLASS_EX(OP, NupharTVM, X86) + +#define NUPHAR_TVM_X86_OP_IR_CREATOR_STRING(OP) \ + STRINGIZE(NUPHAR_TVM_X86_OP_IR_CREATOR_CLASS(OP)) + +#define LIST_X86_UNARY_OPS() \ + UNARY_OP(Erf) \ + UNARY_OP(Exp) \ + UNARY_OP(Log) \ + UNARY_OP(ParametricSoftplus) \ + UNARY_OP(ScaledTanh) \ + UNARY_OP(Selu) \ + UNARY_OP(Sigmoid) \ + UNARY_OP(Softplus) \ + UNARY_OP(Tanh) + +#define LIST_REDUCE_V_OPS() \ + REDUCE_V_OP(ReduceMax) \ + REDUCE_V_OP(ReduceMin) \ + REDUCE_V_OP(ReduceSum) + +#define LIST_ALL_X86_OPS() \ + LIST_REDUCE_V_OPS() \ + LIST_X86_UNARY_OPS() \ + ADD_OP_ITEM(Gemm) \ + ADD_OP_ITEM(LogSoftmax) \ + ADD_OP_ITEM(MatMul) \ + ADD_OP_ITEM(MatMulInteger) \ + ADD_OP_ITEM(MatMulInteger16) \ + ADD_OP_ITEM(Slice) \ + ADD_OP_ITEM(Softmax) \ + ADD_OP_ITEM(Tile) + +// Define all OPs for NupharTVMX86 +#define ADD_OP_ITEM(OP) DECLARE_NUPHAR_TVM_X86_OP_IR_CREATOR_CLASS(OP) +#define REDUCE_V_OP(OP) ADD_OP_ITEM(OP) +#define UNARY_OP(OP) ADD_OP_ITEM(OP) + +LIST_ALL_X86_OPS() + +#undef ADD_OP_ITEM +#undef REDUCE_V_OP +#undef UNARY_OP + +} // namespace nuphar +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/nuphar/compiler/x86/op_ir_creator/math/gemm.cc b/onnxruntime/core/providers/nuphar/compiler/x86/op_ir_creator/math/gemm.cc new 
file mode 100644 index 0000000000000..5ac2adf738017 --- /dev/null +++ b/onnxruntime/core/providers/nuphar/compiler/x86/op_ir_creator/math/gemm.cc @@ -0,0 +1,52 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +#include "core/codegen/mti/math/binary_ops.h" +#include "core/codegen/mti/math/gemm.h" +#include "core/framework/op_kernel_info.h" +#include "core/providers/common.h" +#include "core/providers/nuphar/compiler/x86/op_ir_creator/all_ops.h" +#include "core/providers/nuphar/compiler/nuphar_codegen_ctx.h" +#include "core/providers/nuphar/mti_x86/math/matmul_ops.h" + +namespace onnxruntime { +namespace nuphar { + +Status NUPHAR_TVM_X86_OP_IR_CREATOR_CLASS(Gemm)::Evaluate( + const tvm::Array& inputs, + const Node& node, + tvm_codegen::CodeGenContext& ctx_codegen, + tvm::Array& outputs) { + ProtoHelperNodeContext ctx(node); + OpNodeProtoHelper info(&ctx); + + tvm::Tensor Y; + auto& A = inputs[0]; + auto& B = inputs[1]; + auto& C = inputs[2]; + int64_t trans_a, trans_b; + float alpha, beta; + ORT_RETURN_IF_ERROR(info.GetAttr("transA", &trans_a)); + ORT_RETURN_IF_ERROR(info.GetAttr("transB", &trans_b)); + ORT_RETURN_IF_ERROR(info.GetAttr("alpha", &alpha)); + ORT_RETURN_IF_ERROR(info.GetAttr("beta", &beta)); + + // use native sgemm for floating point + if (A->dtype == HalideIR::Float(32) && + B->dtype == HalideIR::Float(32) && + MatMulExternCpu(A, B, Y, !!trans_a, !!trans_b, node.Name() + "_gemm")) { + if (beta != 0) { + tvm::Tensor beta_bias = (beta == 1) ? C : tvm_codegen::Mul(tvm::make_const(tvm::Float(32), beta), C); + Y = tvm_codegen::Add((alpha == 1) ? Y : tvm_codegen::Mul(tvm::make_const(tvm::Float(32), alpha), Y), beta_bias, node.Name() + "_add_bias"); + } + outputs.push_back(Y); + return Status::OK(); + } + + // fallback to default MTI ops + Y = tvm_codegen::Gemm(A, B, C, trans_a, trans_b, alpha, beta, node.Name()); + outputs.push_back(Y); + return Status::OK(); +} + +} // namespace nuphar +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/nuphar/compiler/x86/op_ir_creator/math/logsoftmax.cc b/onnxruntime/core/providers/nuphar/compiler/x86/op_ir_creator/math/logsoftmax.cc new file mode 100644 index 0000000000000..aef32e3d3c81d --- /dev/null +++ b/onnxruntime/core/providers/nuphar/compiler/x86/op_ir_creator/math/logsoftmax.cc @@ -0,0 +1,32 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
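When the extern sgemm path above is taken, alpha and beta are folded in afterwards as elementwise ops, so the result still matches the usual Gemm definition Y = alpha * (A x B) + beta * C, with the alpha == 1 and beta == 0/1 short-cuts skipping the redundant multiplies. A scalar reference sketch of that composition; GemmRef is illustrative only (row-major, no transpose).

#include <cassert>
#include <vector>

// Minimal scalar reference: Y = alpha * (A x B) + beta * C for 2-D inputs.
std::vector<float> GemmRef(const std::vector<float>& A, const std::vector<float>& B,
                           const std::vector<float>& C, int M, int K, int N,
                           float alpha, float beta) {
  std::vector<float> Y(M * N, 0.0f);
  for (int m = 0; m < M; ++m)
    for (int n = 0; n < N; ++n) {
      float acc = 0.0f;
      for (int k = 0; k < K; ++k) acc += A[m * K + k] * B[k * N + n];
      // beta == 0 drops the bias term entirely, as in the creator above
      Y[m * N + n] = alpha * acc + (beta == 0.0f ? 0.0f : beta * C[m * N + n]);
    }
  return Y;
}

int main() {
  // 1x2 * 2x1 with alpha=2, beta=3: 2*(1*3 + 2*4) + 3*5 = 37
  auto Y = GemmRef({1, 2}, {3, 4}, {5}, 1, 2, 1, 2.0f, 3.0f);
  assert(Y[0] == 37.0f);
  return 0;
}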
+ +#include "core/providers/nuphar/compiler/x86/op_ir_creator/all_ops.h" + +#include "core/providers/nuphar/mti_x86/math/logsoftmax.h" +#include "core/framework/op_kernel_info.h" +#include "core/providers/common.h" + +namespace onnxruntime { +namespace nuphar { + +// Evaluate of LogSoftmax OpIRCreator +Status NUPHAR_TVM_X86_OP_IR_CREATOR_CLASS(LogSoftmax)::Evaluate( + const tvm::Array& inputs, + const Node& node, + tvm_codegen::CodeGenContext&, + tvm::Array& outputs) { + ProtoHelperNodeContext ctx(node); + OpNodeProtoHelper info(&ctx); + + int64_t axis_i64; + ORT_RETURN_IF_ERROR(info.GetAttr("axis", &axis_i64)); + + axis_i64 = HandleNegativeAxis(axis_i64, gsl::narrow_cast(inputs[0]->shape.size())); + tvm::Tensor Y = nuphar::LogSoftmax(inputs[0], axis_i64); + outputs.push_back(Y); + return Status::OK(); +} + +} // namespace nuphar +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/nuphar/compiler/x86/op_ir_creator/math/matmul.cc b/onnxruntime/core/providers/nuphar/compiler/x86/op_ir_creator/math/matmul.cc new file mode 100644 index 0000000000000..e81ef497c50a8 --- /dev/null +++ b/onnxruntime/core/providers/nuphar/compiler/x86/op_ir_creator/math/matmul.cc @@ -0,0 +1,148 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/nuphar/compiler/x86/op_ir_creator/all_ops.h" + +#include "core/providers/nuphar/compiler/nuphar_codegen_ctx.h" +#include "core/providers/nuphar/mti_x86/math/matmul_ops.h" +#include "core/codegen/mti/mti_tvm_utils.h" +#include "core/codegen/passes/weight_layout/transpose_2d.h" +#include "core/codegen/passes/weight_layout/vertical_stripes_2d.h" +#include "core/providers/nuphar/compiler/x86/x86_target_info.h" + +#include + +namespace onnxruntime { +namespace nuphar { + +// TODO: remove tvm core function + +// local helper functions + +static bool MatMul_weights2D( + ONNX_NAMESPACE::TensorProto_DataType proto_type, + const tvm::Tensor& A, + const tvm::Tensor& B, + const std::string& initializer_name, + NupharCodeGenCtx& ctx_codegen, + tvm::Tensor& Y, + const std::string& name = "matmul_weights2d") { + NupharCodeGenCtx* ctx_nuphar = Promote(&ctx_codegen); + + // optimizations for B being 2D weights + + // The 2D weight is marshalled with stripe_width. 
+ // This should be 2x nature vector width + int stripe_width = 8; + int block_size = 32; + + onnxruntime::CodeGenTargetX86* target = + dynamic_cast(ctx_codegen.GetCodeGenHandle()->codegen_target); + if (nullptr != target) { + stripe_width = 2 * target->NaturalVectorWidth(B->dtype.bits()); + } + + // align A, B to multiple of block size + const auto& A_shape = A->shape; + tvm::Expr A0_size = tvm_codegen::SizeToDimension(A_shape, -1); + auto A0_roundup = tvm_codegen::RoundUp(A0_size, block_size); + tvm::Expr A1_size = tvm_codegen::SizeFromDimension(A_shape, -1); + auto A1_roundup = tvm_codegen::RoundUp(A1_size, block_size); + bool A0_need_pad = !tvm::ir::Equal(A0_roundup, A0_size); + bool A1_need_pad = !tvm::ir::Equal(A1_roundup, A1_size); + + const auto& B_shape = B->shape; + tvm::Expr B0_size = tvm_codegen::SizeToDimension(B_shape, 1); + auto B0_roundup = tvm_codegen::RoundUp(B0_size, block_size); + tvm::Expr B1_size = tvm_codegen::SizeFromDimension(B_shape, 1); + auto B1_roundup = tvm_codegen::RoundUp(B1_size, block_size); + bool B1_need_pad = !tvm::ir::Equal(B1_roundup, B1_size); + + ORT_ENFORCE(tvm::ir::Equal(A1_roundup, B0_roundup)); + + // Currently only support padding in B1, as it's free with memory marshalling + if (A0_need_pad || A1_need_pad || B1_need_pad) + return false; + + auto layout_key = tvm_codegen::WeightLayoutVerticalStripe2D::GetKey(proto_type, stripe_width); + auto B_unmarshalled = ctx_nuphar->ApplyWeightLayout(layout_key, initializer_name, B, false); + + ORT_ENFORCE(B_unmarshalled->op.as()); + + tvm::Array Y_shape; + for (size_t d = 0; d < A->shape.size() - 1; ++d) + Y_shape.push_back(A->shape[d]); + Y_shape.push_back(B->shape[1]); + + auto k = tvm::reduce_axis(tvm::Range(0, A1_size), "k"); + Y = tvm::compute( + Y_shape, + [&](const tvm::Array& idx) { + tvm::Array A_indices; + for (size_t d = 0; d < idx.size() - 1; ++d) + A_indices.push_back(idx[d]); + A_indices.push_back(k); + return tvm::sum(A(A_indices) * B_unmarshalled(k, idx[idx.size() - 1]), {k}); + }, + name); + + return true; +} + +static bool MatMulF32ExternCpuEx( + ONNX_NAMESPACE::TensorProto_DataType proto_type, + NupharCodeGenCtx& ctx_nuphar, + const tvm::Tensor& A, + const tvm::Tensor& B, + tvm::Tensor& Y, + const std::string& B_initializer_name = "", + bool trans_a = false, + bool trans_b = false, + const std::string& name = "matmul_extern_cpu_ex") { + // transpose weights if not already + tvm::Tensor actual_B = B; + + if (ctx_nuphar.IsInitializer(B_initializer_name) && !trans_b) { + auto layout_key = tvm_codegen::WeightLayoutTranspose2D::GetKey(proto_type); + actual_B = ctx_nuphar.ApplyWeightLayout(layout_key, B_initializer_name, B, true); + trans_b = true; + } + + return nuphar::MatMulExternCpu(A, actual_B, Y, trans_a, trans_b, name); +} + +Status NUPHAR_TVM_X86_OP_IR_CREATOR_CLASS(MatMul)::Evaluate( + const tvm::Array& inputs, + const Node& node, + tvm_codegen::CodeGenContext& ctx_codegen, + tvm::Array& outputs) { + NupharCodeGenCtx* ctx_nuphar = Promote(&ctx_codegen); + + auto proto_type = TensorProtoDataType(node.InputDefs()[1]); + + tvm::Tensor Y; + auto& A = inputs[0]; + auto& B = inputs[1]; + const std::string& input_1_name = node.InputDefs()[1]->Name(); + + if (A->dtype == HalideIR::Float(32) && + B->dtype == HalideIR::Float(32) && + MatMulF32ExternCpuEx(proto_type, *ctx_nuphar, A, B, Y, input_1_name)) { + outputs.push_back(Y); + return Status::OK(); + } + + if (ShapeRank(node.InputDefs()[1]) == 2 && ctx_nuphar->IsInitializer(input_1_name)) { + if (MatMul_weights2D(proto_type, A, B, 
input_1_name, *ctx_nuphar, Y)) { + outputs.push_back(Y); + return Status::OK(); + } + } + + Y = nuphar::MatMul(A, B, node.Name()); + outputs.push_back(Y); + return Status::OK(); +} + +} // namespace nuphar +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/nuphar/compiler/x86/op_ir_creator/math/quantize/matmul_integer.cc b/onnxruntime/core/providers/nuphar/compiler/x86/op_ir_creator/math/quantize/matmul_integer.cc new file mode 100644 index 0000000000000..1fbf3516c46f1 --- /dev/null +++ b/onnxruntime/core/providers/nuphar/compiler/x86/op_ir_creator/math/quantize/matmul_integer.cc @@ -0,0 +1,126 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/nuphar/compiler/x86/op_ir_creator/all_ops.h" + +#include "core/codegen/mti/math/binary_ops.h" +#include "core/codegen/mti/math/matmul_ops.h" +#include "core/codegen/mti/mti_tvm_utils.h" +#include "core/codegen/mti/tensor/cast_ops.h" +#include "core/codegen/mti/tensor/reshape_ops.h" +#include "core/codegen/mti/tensor/transpose.h" +#include "core/codegen/passes/weight_layout/transpose_2d.h" +#include "core/common/cpuid_info.h" // TODO: refactor to control through config +#include "core/providers/nuphar/common/nuphar_settings.h" +#include "core/providers/nuphar/compiler/nuphar_codegen_ctx.h" +#include "core/providers/nuphar/mti_x86/quantize/imatmul_extern.h" +#include "core/providers/nuphar/mti_x86/quantize/imatmul16_extern.h" + +namespace onnxruntime { +namespace nuphar { + +// Evaluate of MatMulInteger or MatMulInteger16 +static Status EvaluateMatMulInteger( + const tvm::Array& inputs, + const Node& node, + tvm_codegen::CodeGenContext& ctx_codegen, + tvm::Array& outputs) { + NupharCodeGenCtx* ctx_nuphar = Promote(&ctx_codegen); + + const auto& A = inputs[0]; + const auto& B = inputs[1]; + auto& name = node.Name(); + + if (B->shape.size() == 2) { + const int64_t* p_input_dim = tvm::as_const_int(B->shape[0]); + const int64_t* p_embed_dim = tvm::as_const_int(B->shape[1]); + + if (p_input_dim != nullptr && p_embed_dim != nullptr) { + int64_t input_dim = *p_input_dim; + int64_t embed_dim = *p_embed_dim; + + bool is16bitSymm = (A->dtype == HalideIR::type_of() && + B->dtype == HalideIR::type_of()); + bool is8bitAsymm = (A->dtype == HalideIR::type_of() && + B->dtype == HalideIR::type_of()); + + if (is16bitSymm || is8bitAsymm) { + auto A_rank = gsl::narrow_cast(A->shape.size()); + + tvm::Array output_shape; + for (int i = 0; i < A_rank - 1; ++i) { + output_shape.push_back(A->shape[i]); + } + output_shape.push_back(tvm::Expr(gsl::narrow_cast(embed_dim))); + + tvm::Tensor B_marshalled; + auto B_NodeArg = node.InputDefs()[1]; + const std::string& B_name = B_NodeArg->Name(); + + if (ctx_nuphar->IsInitializer(B_name)) { + auto layout_key = tvm_codegen::WeightLayoutTranspose2D::GetKey(TensorProtoDataType(B_NodeArg)); + B_marshalled = ctx_nuphar->ApplyWeightLayout(layout_key, B_name, B, true); + } else { + B_marshalled = tvm_codegen::Transpose(B, {1, 0}); + } + + // TODO: add reserved_bits attribute + bool use_AVX2; + const codegen::CodeGenSettings& settings = codegen::CodeGenSettings::Instance(); + if (settings.HasOption(kNupharIMatMulForceMkl)) { + use_AVX2 = false; + } else { + use_AVX2 = CPUIDInfo::GetCPUIDInfo().HasAVX2(); + } + auto output_tensor = + is16bitSymm ? use_AVX2 ? 
IMatMul16ExternAVX2(B_marshalled, A, + output_shape, input_dim, embed_dim, + name + "_IMatMul16ExternAVX2") + : IMatMul16ExternMKL(B_marshalled, A, + output_shape, input_dim, embed_dim, + name + "_IMatMul16ExternMKL") + : use_AVX2 ? IMatMulExternAVX2(B_marshalled, A, + output_shape, input_dim, embed_dim, + name + "_IMatMulExternAVX2") + : IMatMulExternMKL(B_marshalled, A, + output_shape, input_dim, embed_dim, + name + "_IMatMulExternMKL"); + + outputs.push_back(output_tensor); + return Status::OK(); + } + } + } + // slow path, cast to int32 for now + // Support skipped trailing inputs + auto A_Int32 = (node.InputDefs().size() >= 3 && node.InputDefs()[2]->Exists()) + ? tvm_codegen::Sub(tvm_codegen::Cast(A, HalideIR::Int(32)), tvm_codegen::Cast(inputs[2], HalideIR::Int(32))) + : tvm_codegen::Cast(A, HalideIR::Int(32)); + auto B_Int32 = (node.InputDefs().size() >= 4 && node.InputDefs()[3]->Exists()) + ? tvm_codegen::Sub(tvm_codegen::Cast(B, HalideIR::Int(32)), tvm_codegen::Cast(inputs[3], HalideIR::Int(32))) + : tvm_codegen::Cast(B, HalideIR::Int(32)); + tvm::Tensor Y = tvm_codegen::MatMul(A_Int32, B_Int32, name); + outputs.push_back(Y); + return Status::OK(); +} + +// Evaluate of MatMulInteger OpIRCreator +Status NUPHAR_TVM_X86_OP_IR_CREATOR_CLASS(MatMulInteger)::Evaluate( + const tvm::Array& inputs, + const Node& node, + tvm_codegen::CodeGenContext& ctx_codegen, + tvm::Array& outputs) { + return EvaluateMatMulInteger(inputs, node, ctx_codegen, outputs); +} + +// Evaluate of MatMulInteger16 OpIRCreator +Status NUPHAR_TVM_X86_OP_IR_CREATOR_CLASS(MatMulInteger16)::Evaluate( + const tvm::Array& inputs, + const Node& node, + tvm_codegen::CodeGenContext& ctx_codegen, + tvm::Array& outputs) { + return EvaluateMatMulInteger(inputs, node, ctx_codegen, outputs); +} + +} // namespace nuphar +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/nuphar/compiler/x86/op_ir_creator/math/reduce_ops.cc b/onnxruntime/core/providers/nuphar/compiler/x86/op_ir_creator/math/reduce_ops.cc new file mode 100644 index 0000000000000..5dee08e0b51ca --- /dev/null +++ b/onnxruntime/core/providers/nuphar/compiler/x86/op_ir_creator/math/reduce_ops.cc @@ -0,0 +1,169 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
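The slow path above implements the standard MatMulInteger semantics: widen both operands to int32, subtract the optional zero points (inputs 2 and 3) when present, then multiply-accumulate. A scalar reference sketch; MatMulIntegerRef is illustrative only and assumes per-tensor zero points.

#include <cassert>
#include <cstdint>
#include <vector>

// Scalar reference for the slow path: (A - a_zero) x (B - b_zero) in int32.
std::vector<int32_t> MatMulIntegerRef(const std::vector<uint8_t>& A, const std::vector<int8_t>& B,
                                      int M, int K, int N, int32_t a_zero, int32_t b_zero) {
  std::vector<int32_t> Y(M * N, 0);
  for (int m = 0; m < M; ++m)
    for (int n = 0; n < N; ++n)
      for (int k = 0; k < K; ++k)
        Y[m * N + n] += (static_cast<int32_t>(A[m * K + k]) - a_zero) *
                        (static_cast<int32_t>(B[k * N + n]) - b_zero);
  return Y;
}

int main() {
  // 1x2 * 2x1 with zero points 128 and 0: (130-128)*3 + (131-128)*(-4) = -6
  auto Y = MatMulIntegerRef({130, 131}, {3, -4}, 1, 2, 1, 128, 0);
  assert(Y[0] == -6);
  return 0;
}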
+ +#include "core/providers/nuphar/compiler/x86/op_ir_creator/all_ops.h" + +#include "core/providers/nuphar/mti_x86/math/reduce_ops.h" +#include "core/framework/op_kernel_info.h" +#include "core/providers/common.h" + +#include // for sort + +namespace onnxruntime { +namespace nuphar { + +using ReduceVFunc = tvm::Tensor (*)(const tvm::Tensor& X, + const std::vector& axes, + bool keep_dims, + int32_t vector_size, + bool last_dim_aligned, + int32_t fuse_dim, + const std::string& name); + +// This function gives a proper vector width and fuse dim for reduce +// It avoids vector_width larger than shape +// Fuse dim implies mulitple reduce axis could be fused together to form a longer vector_width +// It can avoid too small vector_width +static std::tuple VectorWidthAndFuseDimForReduce(int natural_width, + std::vector axes, + const NodeArg* def) { + int64_t rank = ShapeRank(def); + if (rank == 0) { + return std::make_tuple(1, 0); + } + + int tail_size = 1; + + // reduce all + if (axes.size() == 0) { + for (int i = gsl::narrow_cast(rank) - 1; i >= 0; --i) { + if (ShapeHasValue(def, i)) { + tail_size *= gsl::narrow_cast(ShapeValue(def, i)); + } else { + if (i > 0) + return std::make_tuple(tail_size, i - 1); + else + return std::make_tuple(natural_width, 0); + } + + if (tail_size >= natural_width) { + return std::make_tuple(natural_width, i); + } + } + + return std::make_tuple(tail_size, 0); + } + + //reduce last + int j = axes.size() - 1; + if (axes.back() == (rank - 1)) { + for (int i = gsl::narrow_cast(rank) - 1; i >= 0; --i) { + if (ShapeHasValue(def, i) && axes[j] == gsl::narrow_cast(i)) { + tail_size *= gsl::narrow_cast(ShapeValue(def, i)); + if (j > 0) + --j; + } else { + if (i > 0) { + return std::make_tuple(tail_size, i - 1); + } else { + return std::make_tuple(natural_width, 0); + } + } + + if (tail_size >= natural_width) { + return std::make_tuple(natural_width, i); + } + } + + return std::make_tuple(tail_size, 0); + } + + // reduce other + for (int i = gsl::narrow_cast(rank) - 1; i >= 0; --i) { + if (ShapeHasValue(def, i) && axes[j] != gsl::narrow_cast(i)) { + tail_size *= gsl::narrow_cast(ShapeValue(def, i)); + if (j > 0) + --j; + } else { + if (i > 0) + return std::make_tuple(tail_size, i - 1); + else + return std::make_tuple(natural_width, 0); + } + + if (tail_size >= natural_width) { + return std::make_tuple(natural_width, i); + } + } + + return std::make_tuple(tail_size, 0); +} + +class FuncReduceV { + public: + FuncReduceV(const Node& node, + ReduceVFunc func, + std::function natural_vector, + const NodeArg* def, + const std::string& name) : def_(def) { + ProtoHelperNodeContext ctx(node); + OpNodeProtoHelper info(&ctx); + axes_ = info.GetAttrsOrDefault("axes"); + std::sort(axes_.begin(), axes_.end()); //ReduceV requires sorted axes + int64_t keepdims_i = 1; + ORT_ENFORCE(info.GetAttr("keepdims", &keepdims_i).IsOK()); + keep_dims_ = (keepdims_i == 1); + func_ = func; + name_ = node.Name() + "_" + name; + natural_vector_ = natural_vector; + } + + tvm::Tensor operator()(const tvm::Tensor& X) const { + std::vector axes; + for (auto i : axes_) { + axes.push_back(HandleNegativeAxis(i, gsl::narrow_cast(X->shape.size()))); + } + + auto p = VectorWidthAndFuseDimForReduce(natural_vector_(X->dtype.bits()), axes, def_); + int vector_width = std::get<0>(p); + int fuse_dim = std::get<1>(p); + + bool last_dim_aligned = false; + const int64_t* p_last_dim_size = tvm::as_const_int(X->shape[X->shape.size() - 1]); + + if (p_last_dim_size != nullptr) { + last_dim_aligned = (*p_last_dim_size) % 
vector_width == 0; + } + + return func_(X, axes, keep_dims_, vector_width, last_dim_aligned, fuse_dim, name_); + } + + private: + std::vector axes_; + bool keep_dims_; + ReduceVFunc func_; + std::string name_; + std::function natural_vector_; + const NodeArg* def_; +}; + +#define REDUCE_V_OP(name) \ + Status NUPHAR_TVM_X86_OP_IR_CREATOR_CLASS(name)::Evaluate( \ + const tvm::Array& inputs, \ + const Node& node, \ + tvm_codegen::CodeGenContext& ctx_codegen, \ + tvm::Array& outputs) { \ + auto natural_vector = [&](int bits) { \ + return ctx_codegen.GetCodeGenHandle()->codegen_target->NaturalVectorWidth(bits); \ + }; \ + tvm::Tensor Y = FuncReduceV(node, &nuphar::name, natural_vector, node.InputDefs()[0], #name)(inputs[0]); \ + outputs.push_back(Y); \ + return Status::OK(); \ + } + +LIST_REDUCE_V_OPS() + +#undef REDUCE_V_OP + +} // namespace nuphar +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/nuphar/compiler/x86/op_ir_creator/math/softmax.cc b/onnxruntime/core/providers/nuphar/compiler/x86/op_ir_creator/math/softmax.cc new file mode 100644 index 0000000000000..28efe2cf1c0f8 --- /dev/null +++ b/onnxruntime/core/providers/nuphar/compiler/x86/op_ir_creator/math/softmax.cc @@ -0,0 +1,32 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/nuphar/compiler/x86/op_ir_creator/all_ops.h" + +#include "core/providers/nuphar/mti_x86/math/softmax.h" +#include "core/framework/op_kernel_info.h" +#include "core/providers/common.h" + +namespace onnxruntime { +namespace nuphar { + +// Evaluate of Softmax OpIRCreator +Status NUPHAR_TVM_X86_OP_IR_CREATOR_CLASS(Softmax)::Evaluate( + const tvm::Array& inputs, + const Node& node, + tvm_codegen::CodeGenContext&, + tvm::Array& outputs) { + ProtoHelperNodeContext ctx(node); + OpNodeProtoHelper info(&ctx); + + int64_t axis_i64; + ORT_RETURN_IF_ERROR(info.GetAttr("axis", &axis_i64)); + + axis_i64 = HandleNegativeAxis(axis_i64, gsl::narrow_cast(inputs[0]->shape.size())); + tvm::Tensor Y = Softmax(inputs[0], axis_i64); + outputs.push_back(Y); + return Status::OK(); +} + +} // namespace nuphar +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/nuphar/compiler/x86/op_ir_creator/math/unary_ops.cc b/onnxruntime/core/providers/nuphar/compiler/x86/op_ir_creator/math/unary_ops.cc new file mode 100644 index 0000000000000..79cf4ecbd38cd --- /dev/null +++ b/onnxruntime/core/providers/nuphar/compiler/x86/op_ir_creator/math/unary_ops.cc @@ -0,0 +1,124 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
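VectorWidthAndFuseDimForReduce walks from the innermost dimension, accumulating the tail extent until it covers the natural vector width; that dimension becomes the fuse dim. A simplified sketch of the reduce-all branch only, assuming every dimension is statically known; WidthAndFuseDim is illustrative, not the full function.

#include <cassert>
#include <tuple>
#include <vector>

// Simplified reduce-all case: grow the tail from the innermost dim and stop
// once it covers the natural vector width; that dim is the fuse dim.
std::tuple<int, int> WidthAndFuseDim(int natural_width, const std::vector<int>& shape) {
  int tail = 1;
  for (int i = static_cast<int>(shape.size()) - 1; i >= 0; --i) {
    tail *= shape[i];
    if (tail >= natural_width) return std::make_tuple(natural_width, i);
  }
  return std::make_tuple(tail, 0);
}

int main() {
  // natural width 16, shape [4, 8, 8]: dim 2 gives tail 8, dim 1 gives 64 >= 16
  assert(WidthAndFuseDim(16, {4, 8, 8}) == std::make_tuple(16, 1));
  // shape smaller than the vector width: fall back to the whole fused extent
  assert(WidthAndFuseDim(16, {2, 3}) == std::make_tuple(6, 0));
  return 0;
}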
+ +#include "core/providers/nuphar/compiler/x86/op_ir_creator/all_ops.h" + +#include "core/codegen/common/op_macro.h" +#include "core/framework/op_kernel_info.h" +#include "core/providers/nuphar/mti_x86/math/unary_ops.h" + +namespace onnxruntime { +namespace nuphar { + +// helper class for unary_ops with alpha +class FuncWithAlpha { + public: + FuncWithAlpha(const Node& node) { + ProtoHelperNodeContext ctx(node); + OpNodeProtoHelper attrs(&ctx); + ORT_ENFORCE(attrs.GetAttr("alpha", &alpha_).IsOK()); + } + + protected: + float alpha_; +}; + +// helper class for unary_ops with alpha and beta +class FuncWithAlphaBeta { + public: + FuncWithAlphaBeta(const Node& node) { + ProtoHelperNodeContext ctx(node); + OpNodeProtoHelper attrs(&ctx); + ORT_ENFORCE(attrs.GetAttr("alpha", &alpha_).IsOK()); + ORT_ENFORCE(attrs.GetAttr("beta", &beta_).IsOK()); + } + + protected: + float alpha_; + float beta_; +}; + +// helper class for unary_ops with alpha and gamma +class FuncWithAlphaGamma { + public: + FuncWithAlphaGamma(const Node& node) { + ProtoHelperNodeContext ctx(node); + OpNodeProtoHelper attrs(&ctx); + ORT_ENFORCE(attrs.GetAttr("alpha", &alpha_).IsOK()); + ORT_ENFORCE(attrs.GetAttr("gamma", &gamma_).IsOK()); + } + + protected: + float alpha_; + float gamma_; +}; + +// helper macro declares unary_ops helper class without attribute +#define FuncClass(name) \ + class Func##name { \ + public: \ + Func##name(const Node&) {} \ + tvm::Tensor operator()(const tvm::Tensor& X) const { \ + return name(X); \ + } \ + } + +// helper macro declares unary_ops helper class with alpha +#define FuncClassAlpha(name) \ + class Func##name : public FuncWithAlpha { \ + public: \ + Func##name(const Node& node) : FuncWithAlpha(node) {} \ + tvm::Tensor operator()(const tvm::Tensor& X) const { \ + return name(X, alpha_); \ + } \ + } + +// helper macro declares unary_ops helper class with alpha and beta +#define FuncClassAlphaBeta(name) \ + class Func##name : public FuncWithAlphaBeta { \ + public: \ + Func##name(const Node& node) : FuncWithAlphaBeta(node) {} \ + tvm::Tensor operator()(const tvm::Tensor& X) const { \ + return name(X, alpha_, beta_); \ + } \ + } + +// helper macro declares unary_ops helper class with alpha and gamma +#define FuncClassAlphaGamma(name) \ + class Func##name : public FuncWithAlphaGamma { \ + public: \ + Func##name(const Node& node) : FuncWithAlphaGamma(node) {} \ + tvm::Tensor operator()(const tvm::Tensor& X) const { \ + return name(X, alpha_, gamma_); \ + } \ + } + +FuncClass(Erf); +FuncClass(Exp); +FuncClass(Log); +FuncClassAlphaBeta(ParametricSoftplus); +FuncClassAlphaBeta(ScaledTanh); +FuncClassAlphaGamma(Selu); +FuncClass(Sigmoid); +FuncClass(Softplus); +FuncClass(Tanh); + +// helper macro defines Evaluate of UNARY_OP OpIRCreators +#define UNARY_OP(name) \ + Status NUPHAR_TVM_X86_OP_IR_CREATOR_CLASS(name)::Evaluate( \ + const tvm::Array& inputs, \ + const Node& node, \ + tvm_codegen::CodeGenContext&, \ + tvm::Array& outputs) { \ + tvm::Tensor Y = Func##name(node)(inputs[0]); \ + outputs.push_back(Y); \ + return Status::OK(); \ + } + +// helper local macros to replace some calls in LIST_UNARY_OPS +LIST_X86_UNARY_OPS() + +#undef UNARY_OP + +} // namespace nuphar +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/nuphar/compiler/x86/op_ir_creator/tensor/slice.cc b/onnxruntime/core/providers/nuphar/compiler/x86/op_ir_creator/tensor/slice.cc new file mode 100644 index 0000000000000..d303f2f6411d7 --- /dev/null +++ 
b/onnxruntime/core/providers/nuphar/compiler/x86/op_ir_creator/tensor/slice.cc @@ -0,0 +1,72 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/nuphar/compiler/x86/op_ir_creator/all_ops.h" + +#include "core/codegen/mti/tensor/tile.h" +#include "core/framework/op_kernel_info.h" +#include "core/providers/common.h" +#include "core/providers/nuphar/compiler/nuphar_codegen_ctx.h" + +namespace onnxruntime { +namespace tvm_codegen { + +// Forwarding +Status SliceCommon(const tvm::Array& inputs, + const Node& node, + tvm::Array& outputs, + const std::vector& starts, + const std::vector& ends, + const std::vector& axes); + +} // namespace tvm_codegen + +namespace nuphar { + +// Evaluate of Slice OpIRCreator +Status NUPHAR_TVM_X86_OP_IR_CREATOR_CLASS(Slice)::Evaluate( + const tvm::Array& inputs, + const Node& node, + tvm_codegen::CodeGenContext& ctx_codegen, + tvm::Array& outputs) { + ProtoHelperNodeContext ctx(node); + OpNodeProtoHelper info(&ctx); + NupharCodeGenCtx* ctx_nuphar = Promote(&ctx_codegen); + + std::vector> slice_params; + int version = ctx_codegen.GetCodeGenHandle()->domain_version_lookup_func(node.Domain()); + if (version <= 9) { + std::vector starts, ends, axes; + ORT_RETURN_IF_ERROR(info.GetAttrs("starts", starts)); + ORT_RETURN_IF_ERROR(info.GetAttrs("ends", ends)); + ORT_RETURN_IF_NOT(starts.size() == ends.size()); + axes = info.GetAttrsOrDefault("axes"); + slice_params.push_back(starts); + slice_params.push_back(ends); + slice_params.push_back(axes); + } else { + // for opset 10 Slice, input 1/2/3/4 are starts/ends/axes/steps + // while axes and steps are optional + ORT_ENFORCE(node.InputDefs().size() < 5, "Slice opset 10: steps is not supported yet"); + for (size_t i = 1; i < 4; ++i) { + if (i < node.InputDefs().size()) { + const auto* tensor = ctx_nuphar->GetOrtInitializerTensor(node.InputDefs()[i]->Name()); + if (tensor) { + if (tensor->DataType() == DataTypeImpl::GetType()) { + const int64_t* data = tensor->Data(); + slice_params.push_back(std::vector(data, data + tensor->Shape().Size())); + } else { + const int32_t* data = tensor->Data(); + slice_params.push_back(std::vector(data, data + tensor->Shape().Size())); + } + continue; + } + } + slice_params.push_back(std::vector()); + } + } + return tvm_codegen::SliceCommon(inputs, node, outputs, slice_params[0], slice_params[1], slice_params[2]); +} + +} // namespace nuphar +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/nuphar/compiler/x86/op_ir_creator/tensor/tile.cc b/onnxruntime/core/providers/nuphar/compiler/x86/op_ir_creator/tensor/tile.cc new file mode 100644 index 0000000000000..2841206eea8c0 --- /dev/null +++ b/onnxruntime/core/providers/nuphar/compiler/x86/op_ir_creator/tensor/tile.cc @@ -0,0 +1,34 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
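For opset <= 9 the starts/ends/axes come from node attributes; from opset 10 onward they are read from initializer inputs 1-3, and steps are rejected above. Either way the parameters feed the same slicing semantics, sketched here in 1-D; Slice1D is illustrative only and ignores steps and the full ONNX clamping rules.

#include <cassert>
#include <vector>

// 1-D sketch of the semantics forwarded to SliceCommon:
// copy the elements in [start, end) with negative indices counted from the end.
std::vector<int> Slice1D(const std::vector<int>& x, long long start, long long end) {
  long long n = static_cast<long long>(x.size());
  if (start < 0) start += n;          // negative index -> offset from the end
  if (end < 0) end += n;
  if (end > n) end = n;               // clamp an oversized end to the extent
  std::vector<int> y;
  for (long long i = start; i < end; ++i) y.push_back(x[static_cast<size_t>(i)]);
  return y;
}

int main() {
  std::vector<int> x{0, 1, 2, 3, 4};
  assert((Slice1D(x, 1, 4) == std::vector<int>{1, 2, 3}));
  assert((Slice1D(x, -3, 1000) == std::vector<int>{2, 3, 4}));  // negative start, clamped end
  return 0;
}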
+ +#include "core/providers/nuphar/compiler/x86/op_ir_creator/all_ops.h" + +#include "core/codegen/mti/tensor/tile.h" +#include "core/framework/op_kernel_info.h" +#include "core/providers/common.h" +#include "core/providers/nuphar/compiler/nuphar_codegen_ctx.h" + +namespace onnxruntime { +namespace nuphar { + +// Evaluate of Tile OpIRCreator +Status NUPHAR_TVM_X86_OP_IR_CREATOR_CLASS(Tile)::Evaluate( + const tvm::Array& inputs, + const Node& node, + tvm_codegen::CodeGenContext& ctx_codegen, + tvm::Array& outputs) { + ProtoHelperNodeContext ctx(node); + OpNodeProtoHelper info(&ctx); + NupharCodeGenCtx* ctx_nuphar = Promote(&ctx_codegen); + const auto* repeats = ctx_nuphar->GetOrtInitializerTensor(node.InputDefs()[1]->Name()); + ORT_RETURN_IF_NOT(repeats != nullptr); + ORT_RETURN_IF_NOT(repeats->Shape().Size() == gsl::narrow(inputs[0]->shape.size())); + const int64_t* repeats_data = repeats->Data(); + const auto repeats_vector = std::vector(repeats_data, repeats_data + inputs[0]->shape.size()); + tvm::Tensor Y = tvm_codegen::Tile(inputs[0], repeats_vector, node.Name() + "_Tile"); + outputs.push_back(Y); + return Status::OK(); +} + +} // namespace nuphar +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/nuphar/compiler/x86/scheduler/analysis_schedule.cc b/onnxruntime/core/providers/nuphar/compiler/x86/scheduler/analysis_schedule.cc new file mode 100644 index 0000000000000..485aa328d75e2 --- /dev/null +++ b/onnxruntime/core/providers/nuphar/compiler/x86/scheduler/analysis_schedule.cc @@ -0,0 +1,31 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/nuphar/compiler/x86/scheduler/nuphar_scheduler.h" + +#include "core/codegen/passes/scheduler/schedule_utils.h" + +namespace onnxruntime { +namespace nuphar { + +// This is for UseCount +bool TVM_SCHEDULER_CLASS(True, NupharX86UseCount)::Evaluate( + const tvm::Tensor& tensor, + const Node*, + tvm_codegen::CodeGenContext&, + tvm_codegen::ScheduleContext& ctx_sched) { + bool status_vec = TryVectorizationX86(tensor, ctx_sched); + bool status_r_and_c = tvm_codegen::InsertRootScheduleAndClosure(tensor, ctx_sched); + return status_vec || status_r_and_c; +} + +bool TVM_SCHEDULER_CLASS(False, NupharX86UseCount)::Evaluate( + const tvm::Tensor& tensor, + const Node*, + tvm_codegen::CodeGenContext&, + tvm_codegen::ScheduleContext& ctx_sched) { + return tvm_codegen::TryInlineSchedule(tensor, ctx_sched); +} + +} // namespace nuphar +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/nuphar/compiler/x86/scheduler/nuphar_scheduler.cc b/onnxruntime/core/providers/nuphar/compiler/x86/scheduler/nuphar_scheduler.cc new file mode 100644 index 0000000000000..736c92ed4fc1c --- /dev/null +++ b/onnxruntime/core/providers/nuphar/compiler/x86/scheduler/nuphar_scheduler.cc @@ -0,0 +1,53 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#include "core/providers/nuphar/compiler/x86/scheduler/nuphar_scheduler.h" + +#include "core/providers/nuphar/compiler/nuphar_codegen_ctx.h" +#include "core/providers/nuphar/common/analysis/subgraph_codegen_stats.h" + +namespace onnxruntime { +namespace nuphar { + +tvm_codegen::Scheduler* SCHEDULE_DISPATCHER_CLASS(NupharX86UseCount):: + Find(const tvm::Tensor&, const Node* node, tvm_codegen::CodeGenContext& ctx) { + if (nullptr == node) + return nullptr; + + NupharCodeGenCtx* ctx_nuphar = Promote(&ctx); + bool reused = Promote(ctx_nuphar->GetGraphStats())->NodeUseCount(node) > 1; + bool cheap_node_reused = Promote(ctx_nuphar->GetGraphStats())->IsCheapNodeReuse(node); + + if (reused && cheap_node_reused) { + return DispatcherBase::Get("True"); + } + return DispatcherBase::Get("False"); +} + +tvm_codegen::Scheduler* SCHEDULE_DISPATCHER_CLASS(NupharX86PartialResult):: + Find(const tvm::Tensor&, const Node* node, tvm_codegen::CodeGenContext&) { + if (nullptr == node) + return DispatcherBase::Get("True"); + return nullptr; +} + +tvm_codegen::Scheduler* SCHEDULE_DISPATCHER_CLASS(NupharX86Tensorize):: + Find(const tvm::Tensor& tensor, const Node* node, tvm_codegen::CodeGenContext&) { + if (nullptr == node) + return nullptr; + + // special checking to bypass tensorization + // when fall back to extern function call + if (tensor->op->InputTensors().size() > 0) { + auto& imatmul = tensor->op->InputTensors()[0]; + auto extern_op = imatmul->op.as(); + // Extern function call + if (nullptr != extern_op) + return nullptr; + } + + return DispatcherBase::Get(node->OpType()); +} + +} // namespace nuphar +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/nuphar/compiler/x86/scheduler/nuphar_scheduler.h b/onnxruntime/core/providers/nuphar/compiler/x86/scheduler/nuphar_scheduler.h new file mode 100644 index 0000000000000..766d251235ea4 --- /dev/null +++ b/onnxruntime/core/providers/nuphar/compiler/x86/scheduler/nuphar_scheduler.h @@ -0,0 +1,41 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
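The UseCount dispatcher above picks between two schedulers: a node that is consumed more than once and passes the cheap-node-reuse check gets the "True" scheduler (compute_root plus vectorize, so the result is materialized once), everything else falls back to inlining. A minimal sketch of that decision; PickSchedule and SchedKind are hypothetical names, not part of the provider.

#include <cassert>

enum class SchedKind { RootAndVectorize, Inline };

// Mirrors the dispatcher: only reused nodes that pass the cheap-node-reuse
// check are materialized; single-use (or filtered) nodes are inlined.
SchedKind PickSchedule(bool reused_by_multiple_consumers, bool cheap_node_reuse_ok) {
  if (reused_by_multiple_consumers && cheap_node_reuse_ok) return SchedKind::RootAndVectorize;
  return SchedKind::Inline;
}

int main() {
  assert(PickSchedule(true, true) == SchedKind::RootAndVectorize);
  assert(PickSchedule(true, false) == SchedKind::Inline);
  assert(PickSchedule(false, true) == SchedKind::Inline);
  return 0;
}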
+ +#pragma once +#include "core/codegen/passes/scheduler/tvm_scheduler.h" +#include + +namespace onnxruntime { +namespace nuphar { + +DECLARE_SCHEDULE_DISPATCHER_CLASS(NupharX86UseCount) +DECLARE_SCHEDULE_DISPATCHER_CLASS(NupharX86PartialResult) +DECLARE_SCHEDULE_DISPATCHER_CLASS(NupharX86Tensorize) + +DECLARE_TVM_SCHEDULER_CLASS(Extern, NupharX86TVMRule) +DECLARE_TVM_SCHEDULER_CLASS(Reduce, NupharX86TVMRule) + +DECLARE_TVM_SCHEDULER_CLASS(MatMulInteger, NupharX86Tensorize) +DECLARE_TVM_SCHEDULER_CLASS(MatMulInteger16, NupharX86Tensorize) +DECLARE_TVM_SCHEDULER_CLASS(Softmax, NupharX86OrtOpType) +DECLARE_TVM_SCHEDULER_CLASS(Gemm, NupharX86OrtOpType) +DECLARE_TVM_SCHEDULER_CLASS(Conv, NupharX86OrtOpType) +DECLARE_TVM_SCHEDULER_CLASS(MatMul, NupharX86OrtOpType) +DECLARE_TVM_SCHEDULER_CLASS(Split, NupharX86OrtOpType) + +DECLARE_TVM_SCHEDULER_CLASS(True, NupharX86UseCount) +DECLARE_TVM_SCHEDULER_CLASS(False, NupharX86UseCount) + +DECLARE_TVM_SCHEDULER_CLASS(True, NupharX86PartialResult) + +// utilities +bool TryVectorizationX86( + const tvm::Tensor& tensor, + tvm_codegen::ScheduleContext& ctx); + +bool InputRootScheduleWithVectorizationX86( + const tvm::Tensor& tensor, + tvm_codegen::ScheduleContext& ctx); + +} // namespace nuphar +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/nuphar/compiler/x86/scheduler/ort_type_schedule.cc b/onnxruntime/core/providers/nuphar/compiler/x86/scheduler/ort_type_schedule.cc new file mode 100644 index 0000000000000..2acdd995826b3 --- /dev/null +++ b/onnxruntime/core/providers/nuphar/compiler/x86/scheduler/ort_type_schedule.cc @@ -0,0 +1,270 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/nuphar/compiler/x86/scheduler/nuphar_scheduler.h" + +#include "core/providers/nuphar/common/analysis/subgraph_codegen_stats.h" +#include "core/providers/nuphar/compiler/nuphar_codegen_ctx.h" +#include "core/codegen/passes/scheduler/schedule_utils.h" +#include "core/providers/nuphar/compiler/x86/scheduler/tensorize/intrin_gemv_ll_extern.h" +#include "core/providers/nuphar/compiler/x86/scheduler/tensorize/intrin_gemv_ll_ir.h" +#include "core/framework/op_kernel_info.h" +#include + +namespace onnxruntime { +namespace nuphar { + +bool TryVectorizationX86( + const tvm::Tensor& tensor, + tvm_codegen::ScheduleContext& ctx) { + // TODO change it to the value from Target + int64_t natural_vector_size = 16; + + return TryVectorization(tensor, natural_vector_size, ctx); +} + +bool InputRootScheduleWithVectorizationX86( + const tvm::Tensor& tensor, + tvm_codegen::ScheduleContext& ctx) { + bool status = false; + for (auto& t : tensor->op->InputTensors()) { + if (t->op->InputTensors().size() > 0) { + bool status_vec = TryVectorizationX86(t, ctx); + bool status_root = InsertRootSchedule(t, ctx); + status = status || status_root || status_vec; + } + } + return status; +} + +bool TVM_SCHEDULER_CLASS(Softmax, NupharX86OrtOpType)::Evaluate( + const tvm::Tensor& tensor, + const Node*, + tvm_codegen::CodeGenContext&, + tvm_codegen::ScheduleContext& ctx_sched) { + bool status_softmax_itself = TryInlineSchedule(tensor, ctx_sched); + + // compute root the exp since it is reused more than once + auto& tensor_exp = tensor->op->InputTensors()[0]; + bool status_vec = TryVectorizationX86(tensor_exp, ctx_sched); + bool status_root = InsertRootSchedule(tensor_exp, ctx_sched); + return status_softmax_itself || status_vec || status_root; +} + +bool TVM_SCHEDULER_CLASS(Split, NupharX86OrtOpType)::Evaluate( + 
const tvm::Tensor& tensor, + const Node*, + tvm_codegen::CodeGenContext&, + tvm_codegen::ScheduleContext& ctx_sched) { + auto& tensor_split_input = tensor->op->InputTensors()[0]; + // force inline for split since to avoid extra copy + bool status_split_itself = TryInlineSchedule(tensor, ctx_sched); + + // add root for split's inputs to avoid inline of the inputs + bool status_vec = TryVectorizationX86(tensor_split_input, ctx_sched); + bool status_input_root = InsertRootSchedule(tensor_split_input, ctx_sched); + return status_split_itself || status_vec || status_input_root; +} + +// Illustration purpose only for tensorization +static Status MatMulTensorization(const tvm::Tensor& tensor, + tvm_codegen::ScheduleContext& ctx) { + if (tensor->shape.size() != 2) + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Gemm output shape should be 2D"); + + // TODO: remove compute_root + InsertRootScheduleAndClosure(tensor, ctx); + +// Demo for Tensorization with llvm extern function +#if 1 + int32_t factor_int32 = 16; + NaiveLLVMExternGemvTensorization tensorization_method("NaiveLLVMExternGemv_Example", {factor_int32, factor_int32}); + + auto shape = tensorization_method.Shape(); + auto compute_op = tensor->op.as(); + auto xy = compute_op->axis; + auto x = xy[0]; + auto y = xy[1]; + auto z = compute_op->reduce_axis[0]; + + tvm::IterVar yo, yi; + ctx.schedule[tensor->op].split(y, shape[0], &yo, &yi); + tvm::IterVar zo, zi; + ctx.schedule[tensor->op].split(z, shape[1], &zo, &zi); + ctx.schedule[tensor->op].reorder({x, yo, zo, yi, zi}); + ctx.schedule[tensor->op].tensorize(yi, tensorization_method.CreateTensorIntrin()); + ctx.schedule[tensor->op].pragma(yo, "import_llvm", tensorization_method.LLVMImportDef()); +#endif + +// Demo for Tensorization with llvm intrisic IR +#if 0 + NaiveLLVMIRGemvTensorization tensorization_method("NaiveLLVMIRGemv_Example"); + + auto shape = tensorization_method.Shape(); + auto compute_op = tensor->op.as(); + auto xy = compute_op->axis; + auto x = xy[0]; + auto y = xy[1]; + auto z = compute_op->reduce_axis[0]; + + tvm::IterVar yo, yi; + ctx.schedule[tensor->op].split(y, shape[0], &yo, &yi); + tvm::IterVar zo, zi; + ctx.schedule[tensor->op].split(z, shape[1], &zo, &zi); + ctx.schedule[tensor->op].reorder({x, yo, zo, yi, zi}); + ctx.schedule[tensor->op].tensorize(yi, tensorization_method.CreateTensorIntrin()); +#endif + + return Status::OK(); +} + +// this is not tested in onnxruntime_test_all, since extern has higher priority +// don't register it +bool TVM_SCHEDULER_CLASS(Gemm, NupharX86OrtOpType)::Evaluate( + const tvm::Tensor& tensor, + const Node* node, + tvm_codegen::CodeGenContext&, + tvm_codegen::ScheduleContext& ctx_sched) { + ProtoHelperNodeContext ctx(*node); + OpNodeProtoHelper attrs(&ctx); + int64_t trans_A_64, trans_B_64; + bool status_a = attrs.GetAttr("transA", &trans_A_64).IsOK(); + ORT_ENFORCE(status_a); + bool status_b = attrs.GetAttr("transB", &trans_B_64).IsOK(); + ORT_ENFORCE(status_b); + + if (trans_A_64 == 0 && trans_B_64 == 1) { + return MatMulTensorization(tensor, ctx_sched).IsOK(); + } + return InsertRootSchedule(tensor, ctx_sched); +} + +// OLD code from Conv schedule +static Status ConvScheduleX86(const tvm::Tensor& tensor, + NupharCodeGenCtx& ctx_codegen, + tvm_codegen::ScheduleContext& ctx_sched, + int block_size) { + if (tensor->shape.size() != 4) + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Conv output shape should be 4D"); + + InsertRootScheduleAndClosure(tensor, ctx_sched); + + auto compute_op = tensor->op.as(); + auto ncyx = compute_op->axis; + auto 
b = ncyx[0]; + auto oc = ncyx[1]; + auto y = ncyx[2]; + auto x = ncyx[3]; + auto ic = compute_op->reduce_axis[0]; + auto m = compute_op->reduce_axis[1]; + auto n = compute_op->reduce_axis[2]; + + tvm::Expr kfactor(4); // todo: this factor for vectorization is tuned for conv2d_performance on AVX2, will need to be addressed later + tvm::IterVar oc_chunk, oc_block; + ctx_sched.schedule[tensor->op].split(oc, kfactor, &oc_chunk, &oc_block); + + tvm::Expr factor(block_size); // factor for tiling and blocking + tvm::IterVar ic_chunk, ic_block; + ctx_sched.schedule[tensor->op].split(ic, factor, &ic_chunk, &ic_block); + + tvm::IterVar xo, xi; + ctx_sched.schedule[tensor->op].split(x, factor, &xo, &xi); + + ctx_sched.schedule[tensor->op].reorder({b, oc_chunk, y, xo, ic_chunk, m, n, ic_block, xi, oc_block}); + + if (ctx_codegen.GetCodeGenHandle()->enable_per_node_parallelized) { + tvm::Array fused_axis; + fused_axis.push_back(b); + fused_axis.push_back(oc_chunk); + fused_axis.push_back(y); + fused_axis.push_back(xo); + tvm::IterVar parallel_axis; + ctx_sched.schedule[tensor->op].fuse(fused_axis, ¶llel_axis); + ctx_sched.schedule[tensor->op].parallel(parallel_axis); + } + ctx_sched.schedule[tensor->op].vectorize(oc_block); + + return Status::OK(); +} + +bool TVM_SCHEDULER_CLASS(Conv, NupharX86OrtOpType)::Evaluate( + const tvm::Tensor& tensor, + const Node* node, + tvm_codegen::CodeGenContext& ctx_codegen, + tvm_codegen::ScheduleContext& ctx_sched) { + NupharCodeGenCtx* ctx_nuphar = Promote(&ctx_codegen); + return ConvScheduleX86(tensor, *ctx_nuphar, ctx_sched, 16).IsOK(); +} // namespace tvm_codegen + +// seems only tested in double path +static Status MatMul_2DWeight_Schedule( + const tvm::Tensor& tensor_C, + NupharCodeGenCtx& ctx_codegen, + tvm_codegen::ScheduleContext& ctx_sched, + int block_size) { + // implementation adapted from: + // https://docs.tvm.ai/tutorials/optimize/opt_gemm.html#sphx-glr-tutorials-optimize-opt-gemm-py + InsertRootScheduleAndClosure(tensor_C, ctx_sched); + + // write cache, note this needs to happen before any axis ops in tensor_C + auto CC = ctx_sched.schedule.cache_write(tensor_C, "global"); + + const auto& C_axis = tensor_C->op.as()->axis; + auto C_rank = C_axis.size(); + auto x = C_axis[C_rank - 2]; + auto y = C_axis[C_rank - 1]; + tvm::Expr block(block_size); + tvm::IterVar xo, yo, xi, yi; + ctx_sched.schedule[tensor_C->op].tile(x, y, block, block, &xo, &yo, &xi, &yi); + ctx_sched.schedule[CC->op].compute_at(ctx_sched.schedule[tensor_C->op], yo); + + // new inner axes + const auto& CC_axis = CC->op.as()->axis; + auto xc = CC_axis[C_rank - 2]; + auto yc = CC_axis[C_rank - 1]; + + constexpr int num_unrolls = 4; + auto split_factor = tvm::Expr(num_unrolls); + auto k = ctx_sched.schedule[CC->op]->op.as()->reduce_axis[0]; + tvm::IterVar ko, ki; + ctx_sched.schedule[CC->op].split(k, split_factor, &ko, &ki); + tvm::Array reordered_axis; + for (size_t d = 0; d < C_rank - 2; ++d) + reordered_axis.push_back(CC_axis[d]); + reordered_axis.push_back(ko); + reordered_axis.push_back(xc); + reordered_axis.push_back(ki); + reordered_axis.push_back(yc); + ctx_sched.schedule[CC->op].reorder(reordered_axis); + ctx_sched.schedule[CC->op].unroll(ki); + ctx_sched.schedule[CC->op].vectorize(yc); + + if (ctx_codegen.GetCodeGenHandle()->enable_per_node_parallelized) { + // parallelize + tvm::Array fused_axis; + for (size_t d = 0; d < C_rank - 2; ++d) + fused_axis.push_back(C_axis[d]); + fused_axis.push_back(xo); + tvm::IterVar fused_xo; + ctx_sched.schedule[tensor_C->op].fuse(fused_axis, 
&fused_xo); + ctx_sched.schedule[tensor_C->op].parallel(fused_xo); + } + + return Status::OK(); +} + +bool TVM_SCHEDULER_CLASS(MatMul, NupharX86OrtOpType)::Evaluate( + const tvm::Tensor& tensor, + const Node* node, + tvm_codegen::CodeGenContext& ctx_codegen, + tvm_codegen::ScheduleContext& ctx_sched) { + NupharCodeGenCtx* ctx_nuphar = Promote(&ctx_codegen); + + if (tensor->dtype != HalideIR::Float(32)) { + return MatMul_2DWeight_Schedule(tensor, *ctx_nuphar, ctx_sched, 16).IsOK(); + } + return InsertRootSchedule(tensor, ctx_sched); +} + +} // namespace nuphar +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/nuphar/compiler/x86/scheduler/partial_schedule.cc b/onnxruntime/core/providers/nuphar/compiler/x86/scheduler/partial_schedule.cc new file mode 100644 index 0000000000000..4af5e016467ba --- /dev/null +++ b/onnxruntime/core/providers/nuphar/compiler/x86/scheduler/partial_schedule.cc @@ -0,0 +1,21 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/nuphar/compiler/x86/scheduler/nuphar_scheduler.h" + +#include "core/codegen/passes/scheduler/schedule_utils.h" + +namespace onnxruntime { +namespace nuphar { + +// This is for ReuseCount +bool TVM_SCHEDULER_CLASS(True, NupharX86PartialResult)::Evaluate( + const tvm::Tensor& tensor, + const Node*, + tvm_codegen::CodeGenContext&, + tvm_codegen::ScheduleContext& ctx_sched) { + return TryInlineSchedule(tensor, ctx_sched); +} + +} // namespace nuphar +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/nuphar/compiler/x86/scheduler/tensorize/intrin_gemv_16bit.cc b/onnxruntime/core/providers/nuphar/compiler/x86/scheduler/tensorize/intrin_gemv_16bit.cc new file mode 100644 index 0000000000000..74470087d53cd --- /dev/null +++ b/onnxruntime/core/providers/nuphar/compiler/x86/scheduler/tensorize/intrin_gemv_16bit.cc @@ -0,0 +1,100 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#include "intrin_gemv_16bit.h" +#include "core/providers/nuphar/compiler/x86/scheduler/tensorize/tensorize_utilities.h" +#include +#include +#include + +namespace onnxruntime { +namespace nuphar { + +Gemv16bitTensorization::Gemv16bitTensorization(const std::string& name, const std::vector& vshape) + : TensorizeBase(name, "Gemv16bitTensorization_Parameter", {vshape[0], vshape[1]}) {} + +tvm::TensorIntrin Gemv16bitTensorization::CreateTensorIntrin() { + tvm::Expr m(shape_[0]); + tvm::Expr l(shape_[1]); + + auto a = tvm::placeholder({l}, HalideIR::Int(16)); + auto b = tvm::placeholder({m, l}, HalideIR::Int(16)); + auto k = tvm::reduce_axis({0, l}); + + auto c = tvm::compute({m}, [&](tvm::Var i) { + return tvm::sum(tvm::cast(HalideIR::Int(32), a(k)) * tvm::cast(HalideIR::Int(32), b(i, k)), {k}); + }); + + auto a_buf = tvm::BufferNode::make( + tvm::Var("a", tvm::Handle()), + a->dtype, + a->shape, + /*strides*/ {1}, + tvm::Var("a_offset"), + "a", + "", + 0, + /*offset_factor*/ 1); + + auto b_buf = tvm::BufferNode::make( + tvm::Var("b", tvm::Handle()), + b->dtype, + b->shape, + /*strides*/ {tvm::Var("s1"), 1}, + tvm::Var("b_offset"), + "b", + "", + 0, + /*offset_factor*/ 1); + + auto c_buf = tvm::BufferNode::make( + tvm::Var("c", tvm::Handle()), + c->dtype, + c->shape, + /*strides*/ {1}, + tvm::Var("c_offset"), + "c", + "", + 0, + /*offset_factor*/ 1); + + int h_unroll = shape_[1] / 16; + auto sum_int32x8 = tvm::make_const(HalideIR::Int(32, 8), 0); + + for (int i = 0; i < h_unroll; ++i) { + auto a_int16x16 = a_buf.vload({i * 16}, HalideIR::Int(16, 16)); + auto b_int16x16 = b_buf.vload({0, i * 16}, HalideIR::Int(16, 16)); + + auto axb_int32x8 = tvm_codegen::LLVMIntrinsic(HalideIR::Int(32, 8), + "llvm.x86.avx2.pmadd.wd", + {a_int16x16, b_int16x16}); + sum_int32x8 += axb_int32x8; + } + + sum_int32x8 = tvm_codegen::LLVMIntrinsic(HalideIR::Int(32, 8), + "llvm.x86.avx2.phadd.d", + {sum_int32x8, sum_int32x8}); + sum_int32x8 = tvm_codegen::LLVMIntrinsic(HalideIR::Int(32, 8), + "llvm.x86.avx2.phadd.d", + {sum_int32x8, sum_int32x8}); + + auto sum_int32x4_l = tvm_codegen::VectorLow(sum_int32x8); + auto sum_int32x4_h = tvm_codegen::VectorHigh(sum_int32x8); + auto sum_int32x4 = sum_int32x4_l + sum_int32x4_h; + auto sum_int32x1 = tvm_codegen::ExtractElement(sum_int32x4, 0); + + auto reset = c_buf.vstore({0}, tvm::make_const(HalideIR::Int(32, 1), 0)); + auto body = c_buf.vstore({0}, sum_int32x1); + auto update = c_buf.vstore({0}, sum_int32x1 + c_buf.vload({0}, HalideIR::Int(32, 1))); + + return tvm::TensorIntrinNode::make( + "intrin_gemv_16bit", + c->op, + {a, b}, + {a_buf, b_buf, c_buf}, + body, + reset, + update); +} +} // namespace nuphar +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/nuphar/compiler/x86/scheduler/tensorize/intrin_gemv_16bit.h b/onnxruntime/core/providers/nuphar/compiler/x86/scheduler/tensorize/intrin_gemv_16bit.h new file mode 100644 index 0000000000000..0f9e460c632aa --- /dev/null +++ b/onnxruntime/core/providers/nuphar/compiler/x86/scheduler/tensorize/intrin_gemv_16bit.h @@ -0,0 +1,20 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
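The tensorization above leans on llvm.x86.avx2.pmadd.wd, which multiplies adjacent signed 16-bit pairs and accumulates each pair into a 32-bit lane, so one intrinsic performs 16 multiplies and 8 pair-adds; the phadd.d calls then reduce across lanes. A scalar model of the pmaddwd step; PmaddwdModel is illustrative only, not the actual TVM lowering.

#include <cassert>
#include <cstdint>
#include <vector>

// Scalar model of pmaddwd: out[i] = a[2i]*b[2i] + a[2i+1]*b[2i+1] in int32.
std::vector<int32_t> PmaddwdModel(const std::vector<int16_t>& a, const std::vector<int16_t>& b) {
  std::vector<int32_t> out(a.size() / 2);
  for (size_t i = 0; i < out.size(); ++i)
    out[i] = static_cast<int32_t>(a[2 * i]) * b[2 * i] +
             static_cast<int32_t>(a[2 * i + 1]) * b[2 * i + 1];
  return out;
}

int main() {
  std::vector<int16_t> a{1, 2, 3, 4};
  std::vector<int16_t> b{5, 6, 7, 8};
  auto out = PmaddwdModel(a, b);  // {1*5 + 2*6, 3*7 + 4*8}
  assert(out[0] == 17 && out[1] == 53);
  return 0;
}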
+ +#pragma once +#include "tensorize_base.h" + +namespace onnxruntime { +namespace nuphar { + +class Gemv16bitTensorization : public tvm_codegen::TensorizeBase { + public: + Gemv16bitTensorization(const std::string& name, const std::vector& vshape); + + virtual ~Gemv16bitTensorization() = default; + + tvm::TensorIntrin CreateTensorIntrin() override; +}; + +} // namespace nuphar +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/nuphar/compiler/x86/scheduler/tensorize/intrin_gemv_8bit.cc b/onnxruntime/core/providers/nuphar/compiler/x86/scheduler/tensorize/intrin_gemv_8bit.cc new file mode 100644 index 0000000000000..4a4dc1695303e --- /dev/null +++ b/onnxruntime/core/providers/nuphar/compiler/x86/scheduler/tensorize/intrin_gemv_8bit.cc @@ -0,0 +1,104 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "intrin_gemv_8bit.h" +#include "core/providers/nuphar/compiler/x86/scheduler/tensorize/tensorize_utilities.h" +#include +#include +#include + +namespace onnxruntime { +namespace nuphar { + +Gemv8bitTensorization::Gemv8bitTensorization(const std::string& name, const std::vector& vshape) + : TensorizeBase(name, "Gemv8bitTensorization_Parameter", {vshape[0], vshape[1]}) {} + +tvm::TensorIntrin Gemv8bitTensorization::CreateTensorIntrin() { + tvm::Expr m(shape_[0]); + tvm::Expr l(shape_[1]); + + auto a = tvm::placeholder({l}, HalideIR::UInt(8)); + auto b = tvm::placeholder({m, l}, HalideIR::Int(8)); + auto k = tvm::reduce_axis({0, l}); + + auto c = tvm::compute({m}, [&](tvm::Var i) { + return tvm::sum(tvm::cast(HalideIR::Int(32), a(k)) * tvm::cast(HalideIR::Int(32), b(i, k)), {k}); + }); + + auto a_buf = tvm::BufferNode::make( + tvm::Var("a", tvm::Handle()), + a->dtype, + a->shape, + /*strides*/ {1}, + tvm::Var("a_offset"), + "a", + "", + 0, + /*offset_factor*/ 1); + + auto b_buf = tvm::BufferNode::make( + tvm::Var("b", tvm::Handle()), + b->dtype, + b->shape, + /*strides*/ {tvm::Var("s1"), 1}, + tvm::Var("b_offset"), + "b", + "", + 0, + /*offset_factor*/ 1); + + auto c_buf = tvm::BufferNode::make( + tvm::Var("c", tvm::Handle()), + c->dtype, + c->shape, + /*strides*/ {1}, + tvm::Var("c_offset"), + "c", + "", + 0, + /*offset_factor*/ 1); + + int h_unroll = shape_[1] / 32; + auto sum_int32x8 = tvm::make_const(HalideIR::Int(32, 8), 0); + auto one = tvm::make_const(HalideIR::Int(16, 16), 1); + + for (int i = 0; i < h_unroll; ++i) { + auto a_uint8x32 = a_buf.vload({i * 32}, HalideIR::UInt(8, 32)); + auto b_int8x32 = b_buf.vload({0, i * 32}, HalideIR::Int(8, 32)); + + auto axb_int16x16 = tvm_codegen::LLVMIntrinsic(HalideIR::Int(16, 16), + "llvm.x86.avx2.pmadd.ub.sw", + {a_uint8x32, b_int8x32}); + auto axb_int32x8 = tvm_codegen::LLVMIntrinsic(HalideIR::Int(32, 8), + "llvm.x86.avx2.pmadd.wd", + {axb_int16x16, one}); + sum_int32x8 += axb_int32x8; + } + + sum_int32x8 = tvm_codegen::LLVMIntrinsic(HalideIR::Int(32, 8), + "llvm.x86.avx2.phadd.d", + {sum_int32x8, sum_int32x8}); + sum_int32x8 = tvm_codegen::LLVMIntrinsic(HalideIR::Int(32, 8), + "llvm.x86.avx2.phadd.d", + {sum_int32x8, sum_int32x8}); + + auto sum_int32x4_l = tvm_codegen::VectorLow(sum_int32x8); + auto sum_int32x4_h = tvm_codegen::VectorHigh(sum_int32x8); + auto sum_int32x4 = sum_int32x4_l + sum_int32x4_h; + auto sum_int32x1 = tvm_codegen::ExtractElement(sum_int32x4, 0); + + auto reset = c_buf.vstore({0}, tvm::make_const(HalideIR::Int(32, 1), 0)); + auto body = c_buf.vstore({0}, sum_int32x1); + auto update = c_buf.vstore({0}, sum_int32x1 + c_buf.vload({0}, HalideIR::Int(32, 
+
+  return tvm::TensorIntrinNode::make(
+      "intrin_gemv_8bit",
+      c->op,
+      {a, b},
+      {a_buf, b_buf, c_buf},
+      body,
+      reset,
+      update);
+}
+} // namespace nuphar
+} // namespace onnxruntime
diff --git a/onnxruntime/core/providers/nuphar/compiler/x86/scheduler/tensorize/intrin_gemv_8bit.h b/onnxruntime/core/providers/nuphar/compiler/x86/scheduler/tensorize/intrin_gemv_8bit.h
new file mode 100644
index 0000000000000..83d366ca1ecde
--- /dev/null
+++ b/onnxruntime/core/providers/nuphar/compiler/x86/scheduler/tensorize/intrin_gemv_8bit.h
@@ -0,0 +1,20 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+#include "tensorize_base.h"
+
+namespace onnxruntime {
+namespace nuphar {
+
+class Gemv8bitTensorization : public tvm_codegen::TensorizeBase {
+ public:
+  Gemv8bitTensorization(const std::string& name, const std::vector& vshape);
+
+  virtual ~Gemv8bitTensorization() = default;
+
+  tvm::TensorIntrin CreateTensorIntrin() override;
+};
+
+} // namespace nuphar
+} // namespace onnxruntime
diff --git a/onnxruntime/core/providers/nuphar/compiler/x86/scheduler/tensorize/intrin_gemv_ll_extern.cc b/onnxruntime/core/providers/nuphar/compiler/x86/scheduler/tensorize/intrin_gemv_ll_extern.cc
new file mode 100644
index 0000000000000..0aedc8178c72f
--- /dev/null
+++ b/onnxruntime/core/providers/nuphar/compiler/x86/scheduler/tensorize/intrin_gemv_ll_extern.cc
@@ -0,0 +1,103 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include "core/providers/nuphar/compiler/x86/scheduler/tensorize/intrin_gemv_ll_extern.h"
+#include "core/providers/nuphar/compiler/x86/scheduler/tensorize/ll/gemv_impl.h"
+#include
+#include
+
+namespace onnxruntime {
+namespace nuphar {
+
+const char* gemv_update_func_name = "gemv_update";
+const char* gemv_reset_func_name = "gemv_reset";
+
+NaiveLLVMExternGemvTensorization::NaiveLLVMExternGemvTensorization(const std::string& name,
+                                                                   const std::vector& shape)
+    : TensorizeWithLLVMImport(name, "NaiveLLVMExternGemvTensorization_Parameter", shape) {}
+
+tvm::TensorIntrin NaiveLLVMExternGemvTensorization::CreateTensorIntrin() {
+  tvm::Expr m(shape_[0]);
+  tvm::Expr l(shape_[1]);
+
+  auto a = tvm::placeholder({l});
+  auto b = tvm::placeholder({m, l});
+  auto k = tvm::reduce_axis({0, l});
+
+  auto c = tvm::compute({m}, [&](tvm::Var i) {
+    return tvm::sum(a(k) * b(i, k), {k});
+  });
+
+  auto a_buf = tvm::BufferNode::make(
+      tvm::Var("a", tvm::Handle()),
+      a->dtype,
+      a->shape,
+      /*strides*/ {1},
+      tvm::Var("a_offset"),
+      "a",
+      "",
+      0,
+      /*offset_factor*/ 1);
+
+  auto b_buf = tvm::BufferNode::make(
+      tvm::Var("b", tvm::Handle()),
+      b->dtype,
+      b->shape,
+      /*strides*/ {tvm::Var("s1"), 1},
+      tvm::Var("b_offset"),
+      "b",
+      "",
+      0,
+      /*offset_factor*/ 1);
+
+  auto c_buf = tvm::BufferNode::make(
+      tvm::Var("c", tvm::Handle()),
+      c->dtype,
+      c->shape,
+      /*strides*/ {1},
+      tvm::Var("c_offset"),
+      "c",
+      "",
+      0,
+      /*offset_factor*/ 1);
+
+  auto body = tvm::ir::Call::make(
+      HalideIR::Type(HalideIR::Type::Int, 32, 1),
+      gemv_update_func_name,
+      {
+          c_buf.access_ptr(static_cast(tvm::AccessMask::kWrite)),
+          a_buf.access_ptr(static_cast(tvm::AccessMask::kRead)),
+          b_buf.access_ptr(static_cast(tvm::AccessMask::kRead)),
+          m,
+          l,
+          /*stride*/ b_buf->strides[0],
+      },
+      tvm::ir::Call::CallType::Extern);
+
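+  // body invokes the extern "C" helper gemv_update (see ll/gemv_impl.cpp,
+  // linked in via the LLVM IR string returned by LLVMImportDef) with a
+  // writable pointer to c, read-only pointers to a and b, the extents m and l,
+  // and the row stride of b.  reduce_init below calls gemv_reset to zero the
+  // m output elements before accumulation begins.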
+  auto reduce_init = tvm::ir::Call::make(
+      HalideIR::Type(HalideIR::Type::Int, 32, 1),
+      gemv_reset_func_name,
+      {
+          c_buf.access_ptr(static_cast(tvm::AccessMask::kWrite)),
+          m,
+      },
+      tvm::ir::Call::CallType::Extern);
+
+  auto reduce_update = body;
+
+  return tvm::TensorIntrinNode::make(
+      "intrin_gemv_ll_extern",
+      c->op,
+      {a, b},
+      {a_buf, b_buf, c_buf},
+      tvm::ir::Evaluate::make(body),
+      tvm::ir::Evaluate::make(reduce_init),
+      tvm::ir::Evaluate::make(reduce_update));
+}
+
+const std::string NaiveLLVMExternGemvTensorization::LLVMImportDef() {
+  return std::string(gemv_stubs_ir);
+}
+
+} // namespace nuphar
+} // namespace onnxruntime
diff --git a/onnxruntime/core/providers/nuphar/compiler/x86/scheduler/tensorize/intrin_gemv_ll_extern.h b/onnxruntime/core/providers/nuphar/compiler/x86/scheduler/tensorize/intrin_gemv_ll_extern.h
new file mode 100644
index 0000000000000..6a227c746a9e5
--- /dev/null
+++ b/onnxruntime/core/providers/nuphar/compiler/x86/scheduler/tensorize/intrin_gemv_ll_extern.h
@@ -0,0 +1,13 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+#include "tensorize_base.h"
+
+namespace onnxruntime {
+namespace nuphar {
+
+TENSORIZE_CLASS_WITH_LLVM_IMPORT(NaiveLLVMExternGemvTensorization)
+
+} // namespace nuphar
+} // namespace onnxruntime
diff --git a/onnxruntime/core/providers/nuphar/compiler/x86/scheduler/tensorize/intrin_gemv_ll_ir.cc b/onnxruntime/core/providers/nuphar/compiler/x86/scheduler/tensorize/intrin_gemv_ll_ir.cc
new file mode 100644
index 0000000000000..cdaefe9f87e4c
--- /dev/null
+++ b/onnxruntime/core/providers/nuphar/compiler/x86/scheduler/tensorize/intrin_gemv_ll_ir.cc
@@ -0,0 +1,96 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include "intrin_gemv_ll_ir.h"
+
+#include "core/providers/nuphar/compiler/x86/scheduler/tensorize/tensorize_utilities.h"
+#include
+#include
+#include
+
+namespace onnxruntime {
+namespace nuphar {
+
+const int32_t dim0 = 1;
+const int32_t dim1 = 8;
+
+NaiveLLVMIRGemvTensorization::NaiveLLVMIRGemvTensorization(const std::string& name)
+    : TensorizeBase(name, "NaiveLLVMIRGemvTensorization_Parameter", {dim0, dim1}) {}
+
+tvm::TensorIntrin NaiveLLVMIRGemvTensorization::CreateTensorIntrin() {
+  tvm::Expr m(dim0);
+  tvm::Expr l(dim1);
+
+  auto a = tvm::placeholder({l});
+  auto b = tvm::placeholder({m, l});
+  auto k = tvm::reduce_axis({0, l});
+
+  auto c = tvm::compute({m}, [&](tvm::Var i) {
+    return tvm::sum(a(k) * b(i, k), {k});
+  });
+
+  auto a_buf = tvm::BufferNode::make(
+      tvm::Var("a", tvm::Handle()),
+      a->dtype,
+      a->shape,
+      /*strides*/ {1},
+      tvm::Var("a_offset"),
+      "a",
+      "",
+      0,
+      /*offset_factor*/ 1);
+
+  auto b_buf = tvm::BufferNode::make(
+      tvm::Var("b", tvm::Handle()),
+      b->dtype,
+      b->shape,
+      /*strides*/ {tvm::Var("s1"), 1},
+      tvm::Var("b_offset"),
+      "b",
+      "",
+      0,
+      /*offset_factor*/ 1);
+
+  auto c_buf = tvm::BufferNode::make(
+      tvm::Var("c", tvm::Handle()),
+      c->dtype,
+      c->shape,
+      /*strides*/ {1},
+      tvm::Var("c_offset"),
+      "c",
+      "",
+      0,
+      /*offset_factor*/ 1);
+
+  auto a_float32x8 = a_buf.vload({0}, HalideIR::Float(32, 8));
+  auto b_float32x8 = b_buf.vload({0, 0}, HalideIR::Float(32, 8));
+  auto z_float32x8 = tvm::make_const(HalideIR::Float(32, 8), 0);
+
+  auto axb = tvm_codegen::LLVMIntrinsic(HalideIR::Float(32, 8),
+                                        "llvm.x86.fma.vfmadd.ps.256",
+                                        {a_float32x8,
+                                         b_float32x8,
+                                         z_float32x8});
+
+  auto sum = tvm_codegen::ExtractElement(axb, 0);
+
+  for (int i = 1; i < 8; ++i) {
+    auto z0 = tvm_codegen::ExtractElement(axb, i);
+    sum += z0;
+  }
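+  // The single 256-bit FMA above multiplies a with the one row of b
+  // (accumulating into a zero vector), and the loop folds the eight lanes
+  // into a scalar by extracting and adding each element.  The stores that
+  // follow (body/reset/update) give TVM the write-out, initialization and
+  // accumulation steps it needs for a reduction tensor intrinsic.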
+  auto body = c_buf.vstore({0}, sum);
+  auto reset = c_buf.vstore({0}, tvm::make_const(HalideIR::Float(32, 1), 0));
+  auto update = c_buf.vstore({0}, sum + c_buf.vload({0}, HalideIR::Float(32, 1)));
+
+  return tvm::TensorIntrinNode::make(
+      "intrin_gemv_ll_ir",
+      c->op,
+      {a, b},
+      {a_buf, b_buf, c_buf},
+      body,
+      reset,
+      update);
+}
+} // namespace nuphar
+} // namespace onnxruntime
diff --git a/onnxruntime/core/providers/nuphar/compiler/x86/scheduler/tensorize/intrin_gemv_ll_ir.h b/onnxruntime/core/providers/nuphar/compiler/x86/scheduler/tensorize/intrin_gemv_ll_ir.h
new file mode 100644
index 0000000000000..7dad78b35723f
--- /dev/null
+++ b/onnxruntime/core/providers/nuphar/compiler/x86/scheduler/tensorize/intrin_gemv_ll_ir.h
@@ -0,0 +1,20 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+#include "tensorize_base.h"
+
+namespace onnxruntime {
+namespace nuphar {
+
+class NaiveLLVMIRGemvTensorization : public tvm_codegen::TensorizeBase {
+ public:
+  NaiveLLVMIRGemvTensorization(const std::string& name);
+
+  virtual ~NaiveLLVMIRGemvTensorization() = default;
+
+  tvm::TensorIntrin CreateTensorIntrin() override;
+};
+
+} // namespace nuphar
+} // namespace onnxruntime
diff --git a/onnxruntime/core/providers/nuphar/compiler/x86/scheduler/tensorize/ll/gemv_impl.cpp b/onnxruntime/core/providers/nuphar/compiler/x86/scheduler/tensorize/ll/gemv_impl.cpp
new file mode 100644
index 0000000000000..7eff192b1829e
--- /dev/null
+++ b/onnxruntime/core/providers/nuphar/compiler/x86/scheduler/tensorize/ll/gemv_impl.cpp
@@ -0,0 +1,18 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+extern "C" int gemv_update(float* cc, float* aa, float* bb, int m, int l, int stride) {
+  for (int i = 0; i < m; ++i) {
+    for (int j = 0; j < l; ++j) {
+      cc[i] += aa[j] * bb[i * stride + j];
+    }
+  }
+  return 0;
+}
+
+extern "C" int gemv_reset(float* cc, int m) {
+  for (int i = 0; i < m; ++i) {
+    cc[i] = 0.0;
+  }
+  return 0;
+}
diff --git a/onnxruntime/core/providers/nuphar/compiler/x86/scheduler/tensorize/ll/gemv_impl.h b/onnxruntime/core/providers/nuphar/compiler/x86/scheduler/tensorize/ll/gemv_impl.h
new file mode 100644
index 0000000000000..89fdfe5e51148
--- /dev/null
+++ b/onnxruntime/core/providers/nuphar/compiler/x86/scheduler/tensorize/ll/gemv_impl.h
@@ -0,0 +1,137 @@
+// The string in this file is generated using clang:
+// clang++.exe -fno-preserve-as-comments -S -emit-llvm gemv_impl.cpp
+
+namespace onnxruntime {
+namespace nuphar {
+
+const char* gemv_stubs_ir = R"gemv_stub_escape(
+; ModuleID = 'gemv_stubs.cpp'
+source_filename = "gemv_stubs.cpp"
+target datalayout = "e-m:w-i64:64-f80:128-n8:16:32:64-S128"
+target triple = "x86_64-pc-windows-msvc19.11.25548"
+
+; Function Attrs: noinline nounwind optnone uwtable
+define i32 @gemv_update(float*, float*, float*, i32, i32, i32) #0 {
+  %7 = alloca i32, align 4
+  %8 = alloca i32, align 4
+  %9 = alloca i32, align 4
+  %10 = alloca float*, align 8
+  %11 = alloca float*, align 8
+  %12 = alloca float*, align 8
+  %13 = alloca i32, align 4
+  %14 = alloca i32, align 4
+  store i32 %5, i32* %7, align 4
+  store i32 %4, i32* %8, align 4
+  store i32 %3, i32* %9, align 4
+  store float* %2, float** %10, align 8
+  store float* %1, float** %11, align 8
+  store float* %0, float** %12, align 8
+  store i32 0, i32* %13, align 4
+  br label %15
+
+;