forked from apache/mxnet
-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Pointwise fusion for GPU (apache#15167)
* Beginning of RTC of pointwise ops * Code generation from the given JSON * add initial simple_partition_pass and use it for pointwise fusion * fix the fusion, use a symbol.Copy() at the beginning of binding function, use the name of input nodes in the cuda code * Fixes * Adding support for attribute inference for backward nodes when fusing * keep proper input ordering for fused Op * instantiate the indexed_graph before starting the subgraph replacement, return a new graph to reset the indexed_graph * Fuse backward * fix ordering of subgraph node inputs using subgraph topological ordering instead of main graph topological ordering, add tvm.patch * excluse forward node fusion during the fusion of the nodes in the backward graph * Dealing with fused backward nodes inferattr * use subgraph.indexed_graph() instead of main for _FusedOpHelper nodes node_id, invert control_deps loop to modify topology of subgraph before calling its indexed_graph(), check that all node of the first DFSVisit are actually in the subgraph * Adding support for other reqs in codegen * Fix * Cleaning * Change the TVM submodule * More cleaning * Making linter happy * Do fusion only if default context is GPU * Fixes for tests Add powerscalar and rpowerscalar, fix return type of zero and one Cleaning, fixing lint Go back to proper TVM submodule * Fix the TVM commit * Fix lint * Guard fusion with MXNET_USE_CUDA * Fix * Fix clang-tidy * Add erf and erfinv backward * Gluon support for fusion * Cleaning * Cleaning and allow shape/type change in FusedOp * Fixing Gluon bugs * Fixing after rebase * Fixing race condition and guarding against races when using NVRTC * Cleaning and renaming FusedOp to _FusedOp * Going easy on Windows compiler * Disable fusion on Windows for now * Refactor InferAttr and InferShapeAttr * Added slice and half2 support to FusedOp * Fix lint errors * Added multiple types support for vector loading/storing * add slice fusion when it's at the beginning of subgraphs * Removed constant ndim assumption in fused op * Fix memory alignment issue in slice for FusedOp * Fixes * Fix lint errors * Do not include cuda_fp16.h * Refactor fused op op lists * Make linter happy * Changes from review * Fixes after rebase * Expand FusedOp support for slice * Fix for fp16 _zeros and _ones * Fix * Moving aux functions to unnamed namespace and detail namespace -> fusion namespace * Disabling fusion if it alters topological order of inputs * Print code only when env variable is set * Fix * Fix lint and 2 tests that specify the same names for multiple inputs * Fixes from review and disabling fusion of slice with non-default step * Add amp_cast to fusion, fixes * Add amp_multicast and its backward to the list of support ops * Apply wording suggestions from code review Co-Authored-By: Aaron Markham <[email protected]> * Apply wording suggestions from code review Co-Authored-By: Aaron Markham <[email protected]> * Make clearer comment * Adding punctuation and capitalization to \brief descriptions * Fix * Fix * Add backward_cast to fusion * Adding unittests for fusion. Fix for erfinv_grad * Adding slice ops and add_n to tests * Fixes from review * Setting inplace option * Fix lint * Storing double in half * Retrigger CI * Slight relaxing of the relative tolerance in the test * Move the env variable check to the end * Fix a race condition between InferShape and scheduled Forward * Fix flakey test_fusion test involving fp32 erfinv op. * Fix from review * Added broadcast_like and slice_like to fused op * Minor fix and cleanup * Added negative axis support in slice_axis, temporarily disabled fusion of slice_like and broadcast_like * Added axes support to slice_like * Added axis support to broadcast_like * Add fast_load_slice function to fused op code * Added runtime switch for choosing fast and slow slice kernel * Fix lint and warning * Going easy on Windows compiler (again) * Fix slice_like * Debug broadcast_like fusion * Fix lint * Fix lint * Trigger CI * Get rid of the initializer list * Fix backward calls with different gradient type * avoid cycle when adding node specific for inputs of subgraph for pointwise fusion * Fix lint * Add namespace to the fusion implementations * Set launch bounds on the fused kernel * Fix NumPy tests * Test showcasing an issue fixed in PR apache#16553 * Cast scalarts to FP32 and perform (a*1.0/b) instead of (a/b) Fix lint errors Fix lint * Fix a bug in cycle detection for inputs only op in pointwise fusion * Add comments to simple_partition_pass.h file
- Loading branch information
1 parent
437f1c7
commit 2a97260
Showing
20 changed files
with
3,862 additions
and
216 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,79 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one | ||
* or more contributor license agreements. See the NOTICE file | ||
* distributed with this work for additional information | ||
* regarding copyright ownership. The ASF licenses this file | ||
* to you under the Apache License, Version 2.0 (the | ||
* "License"); you may not use this file except in compliance | ||
* with the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, | ||
* software distributed under the License is distributed on an | ||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
* KIND, either express or implied. See the License for the | ||
* specific language governing permissions and limitations | ||
* under the License. | ||
*/ | ||
|
||
/*! | ||
* \file exec_utils.cc | ||
* \brief Implementation of executor util functions. | ||
*/ | ||
|
||
#include "exec_utils.h" | ||
#include <unordered_set> | ||
#include <unordered_map> | ||
#include <string> | ||
|
||
namespace mxnet { | ||
namespace common { | ||
|
||
void CopyGraph(nnvm::Graph *dst, const nnvm::Graph &src, bool copy_variables) { | ||
using nnvm::Node; | ||
using nnvm::NodePtr; | ||
using nnvm::NodeEntry; | ||
std::unordered_map<Node*, NodePtr> old_new; | ||
// use DFSVisit to copy all the nodes | ||
DFSVisit(src.outputs, [&old_new, copy_variables](const NodePtr& node) { | ||
NodePtr np; | ||
if (copy_variables || !node->is_variable()) { | ||
np = Node::Create(); | ||
np->attrs = node->attrs; | ||
} else { | ||
np = node; | ||
} | ||
old_new[node.get()] = std::move(np); | ||
}); | ||
// connect nodes of new graph | ||
for (const auto &kv : old_new) { | ||
for (const NodeEntry& e : kv.first->inputs) { | ||
Node *ptr = e.node.get(); | ||
kv.second->inputs.emplace_back(NodeEntry{old_new[ptr], e.index, e.version}); | ||
} | ||
for (const NodePtr& p : kv.first->control_deps) { | ||
kv.second->control_deps.emplace_back(old_new[p.get()]); | ||
} | ||
} | ||
// set the head | ||
for (const NodeEntry &e : src.outputs) { | ||
(*dst).outputs.emplace_back(NodeEntry{old_new[e.node.get()], e.index, e.version}); | ||
} | ||
} | ||
|
||
bool CheckForInputNameDuplicates(const nnvm::IndexedGraph &idx) { | ||
std::unordered_set<std::string> names; | ||
for (const auto& nid : idx.input_nodes()) { | ||
const std::string &name = idx[nid].source->attrs.name; | ||
if (names.count(name)) { | ||
LOG(WARNING) << "Variable name " << name << " is used more than once!"; | ||
return false; | ||
} | ||
names.insert(name); | ||
} | ||
return true; | ||
} | ||
|
||
} // namespace common | ||
} // namespace mxnet |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.