
[Relay] Implement SoftmaxRel for softmax operators. #11728

Merged (2 commits, Jun 16, 2022)
Conversation

wzh99 (Contributor) commented Jun 15, 2022

This PR fixes #11684. It replaces IdentityRel in nn.softmax, nn.fast_softmax, and nn.log_softmax with a newly implemented SoftmaxRel, so that the axis attribute is checked during type inference. For the test case shown in #11684, the following error is now reported:

The axis is not in range [-1, 1)
Traceback (most recent call last):
  File "/Users/wzh/tvm-bug/bug_softmax_axis.py", line 8, in <module>
    mod = relay.transform.InferType()(mod)
  File "/Users/wzh/tvm-dev/python/tvm/ir/transform.py", line 161, in __call__
    return _ffi_transform_api.RunPass(self, mod)
  File "/Users/wzh/tvm-dev/python/tvm/_ffi/_ctypes/packed_func.py", line 237, in __call__
    raise get_last_ffi_error()
tvm._ffi.base.TVMError: Traceback (most recent call last):
  [bt] (8) 9   libtvm.dylib                        0x0000000119ef03b4 tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<void tvm::runtime::TypedPackedFunc<tvm::IRModule (tvm::transform::Pass, tvm::IRModule)>::AssignTypedLambda<tvm::transform::$_6>(tvm::transform::$_6, std::__1::basic_string<char, std::__1::char_traits<char>, std::__1::allocator<char> >)::'lambda'(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)> >::Call(tvm::runtime::PackedFuncObj const*, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*) + 948
  [bt] (7) 8   libtvm.dylib                        0x0000000119ee5964 tvm::transform::Pass::operator()(tvm::IRModule) const + 148
  [bt] (6) 7   libtvm.dylib                        0x0000000119ee5d71 tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const + 753
  [bt] (5) 6   libtvm.dylib                        0x0000000119ee6873 tvm::transform::ModulePassNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const + 819
  [bt] (4) 5   libtvm.dylib                        0x000000011b21ddfd tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<void tvm::runtime::TypedPackedFunc<tvm::IRModule (tvm::IRModule, tvm::transform::PassContext)>::AssignTypedLambda<tvm::relay::transform::InferType()::$_2>(tvm::relay::transform::InferType()::$_2)::'lambda'(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)> >::Call(tvm::runtime::PackedFuncObj const*, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*) + 1933
  [bt] (3) 4   libtvm.dylib                        0x000000011b20d217 tvm::relay::TypeInferencer::Infer(tvm::GlobalVar, tvm::relay::Function) + 135
  [bt] (2) 3   libtvm.dylib                        0x000000011afd2a2f tvm::relay::TypeSolver::Solve() + 1615
  [bt] (1) 2   libtvm.dylib                        0x0000000119b86699 tvm::runtime::detail::LogFatal::Entry::Finalize() + 89
  [bt] (0) 1   libtvm.dylib                        0x000000011b5a3508 tvm::runtime::Backtrace() + 24
  [bt] (8) 9   libtvm.dylib                        0x000000011b20d217 tvm::relay::TypeInferencer::Infer(tvm::GlobalVar, tvm::relay::Function) + 135
  [bt] (7) 8   libtvm.dylib                        0x000000011afd285c tvm::relay::TypeSolver::Solve() + 1148
  [bt] (6) 7   libtvm.dylib                        0x000000011afd2dd0 tvm::TypedEnvFunc<bool (tvm::runtime::Array<tvm::Type, void> const&, int, tvm::Attrs const&, tvm::TypeReporter const&)>::operator()(tvm::runtime::Array<tvm::Type, void> const&, int, tvm::Attrs const&, tvm::TypeReporter const&) const + 416
  [bt] (5) 6   libtvm.dylib                        0x000000011a08b154 tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<void tvm::runtime::TypedPackedFunc<bool (tvm::runtime::Array<tvm::Type, void> const&, int, tvm::Attrs const&, tvm::TypeReporter const&)>::AssignTypedLambda<bool (*)(tvm::runtime::Array<tvm::Type, void> const&, int, tvm::Attrs const&, tvm::TypeReporter const&)>(bool (*)(tvm::runtime::Array<tvm::Type, void> const&, int, tvm::Attrs const&, tvm::TypeReporter const&))::'lambda'(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)> >::Call(tvm::runtime::PackedFuncObj const*, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*) + 20
  [bt] (4) 5   libtvm.dylib                        0x000000011a08b563 void tvm::runtime::TypedPackedFunc<bool (tvm::runtime::Array<tvm::Type, void> const&, int, tvm::Attrs const&, tvm::TypeReporter const&)>::AssignTypedLambda<bool (*)(tvm::runtime::Array<tvm::Type, void> const&, int, tvm::Attrs const&, tvm::TypeReporter const&)>(bool (*)(tvm::runtime::Array<tvm::Type, void> const&, int, tvm::Attrs const&, tvm::TypeReporter const&))::'lambda'(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)::operator()(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*) const + 1027
  [bt] (3) 4   libtvm.dylib                        0x000000011acd163e tvm::relay::SoftmaxRel(tvm::runtime::Array<tvm::Type, void> const&, int, tvm::Attrs const&, tvm::TypeReporter const&) + 942
  [bt] (2) 3   libtvm.dylib                        0x0000000119e6a08b tvm::DiagnosticContext::Render() + 459
  [bt] (1) 2   libtvm.dylib                        0x0000000119b86699 tvm::runtime::detail::LogFatal::Entry::Finalize() + 89
  [bt] (0) 1   libtvm.dylib                        0x000000011b5a3508 tvm::runtime::Backtrace() + 24
  File "/Users/wzh/tvm-dev/src/relay/analysis/type_solver.cc", line 624
TVMError: 
---------------------------------------------------------------
An error occurred during the execution of TVM.
For more information, please see: https://tvm.apache.org/docs/errors.html
---------------------------------------------------------------
  Check failed: (false) is false: [15:01:35] /Users/wzh/tvm-dev/src/ir/diagnostic.cc:105: DiagnosticError: one or more error diagnostics were emitted, please check diagnostic render for output.
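The new SoftmaxRel relation rejects any axis outside the half-open range [-ndim, ndim), where ndim is the rank of the input tensor. As a rough Python sketch of the same range rule (the helper name is hypothetical, not part of the PR):

```python
def softmax_axis_is_valid(axis: int, ndim: int) -> bool:
    # A softmax axis may be given as a negative index (counting from the
    # back), so the accepted range is [-ndim, ndim), matching the check
    # performed by SoftmaxRel during type inference.
    return -ndim <= axis < ndim
```

For a rank-2 input, for example, axis=2 is rejected while axis=-1 is accepted.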

int ndim = static_cast<int>(data->shape.size());
if (axis >= ndim || axis < -ndim) {
  reporter->GetDiagCtx().EmitFatal(Diagnostic::Error(reporter->GetSpan())
                                   << "The axis is not in range [" << -ndim << ", " << ndim << ")");
}

A reviewer (Contributor) commented on this snippet:
A better error message should indicate the wrong axis. For example:

...
<< "Wrong axis (" << axis << ") not in expected range: [" << -ndim << ", " << ndim << ")";

wzh99 (Contributor, Author) replied:
I have modified the error message as you suggested.
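The amended diagnostic names the offending axis, as suggested above. A small Python sketch of that message format (the helper name is hypothetical, not part of the PR):

```python
def softmax_axis_error(axis: int, ndim: int) -> str:
    # Formats the improved diagnostic suggested in review: it names the
    # offending axis alongside the valid half-open range [-ndim, ndim).
    return f"Wrong axis ({axis}) not in expected range: [{-ndim}, {ndim})"
```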

ganler (Contributor) commented Jun 15, 2022

Basically, this PR lets Relay reject an invalid softmax operator (axis >= rank) as early as the type inference phase (though such invalid cases would be rejected anyhow in later checks). @masahi Can you help decide whether we want to merge this improvement? Thanks!

@masahi masahi merged commit 6732a9e into apache:main Jun 16, 2022
@wzh99 wzh99 deleted the softmax-rel branch June 16, 2022 09:05
Successfully merging this pull request may close these issues.

[Bug] Type inference of nn.softmax does not reject invalid axis