Skip to content

Commit

Permalink
[Relay] Implement SoftmaxRel for softmax operators. (#11728)
Browse files Browse the repository at this point in the history
* Implement `SoftmaxRel` for softmax operators.

* Print better error message for wrong axis.
  • Loading branch information
wzh99 authored Jun 16, 2022
1 parent 24010db commit 6732a9e
Showing 1 changed file with 24 additions and 3 deletions.
27 changes: 24 additions & 3 deletions src/relay/op/nn/nn.cc
Original file line number Diff line number Diff line change
Expand Up @@ -399,6 +399,27 @@ where :math:`*` is an channelwise multiplication for each sample in the batch.
// relay.softmax
TVM_REGISTER_NODE_TYPE(SoftmaxAttrs);

bool SoftmaxRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const TypeReporter& reporter) {
ICHECK_EQ(types.size(), 2);
const auto* data = types[0].as<TensorTypeNode>();
if (data == nullptr) return false;

const SoftmaxAttrs* param = attrs.as<SoftmaxAttrs>();
ICHECK(param != nullptr);
int axis = param->axis;
int ndim = static_cast<int>(data->shape.size());
if (axis >= ndim || axis < -ndim) {
reporter->GetDiagCtx().EmitFatal(Diagnostic::Error(reporter->GetSpan())
<< "Wrong axis (" << axis << ") not in expected range: ["
<< -ndim << ", " << ndim << ")");
return false;
}

reporter->Assign(types[1], types[0]);
return true;
}

TVM_REGISTER_GLOBAL("relay.op.nn._make.softmax").set_body_typed([](Expr data, int axis) {
auto attrs = make_object<SoftmaxAttrs>();
attrs->axis = axis;
Expand All @@ -420,7 +441,7 @@ RELAY_REGISTER_OP("nn.softmax")
.set_num_inputs(1)
.add_argument("data", "Tensor", "The input tensor.")
.set_support_level(1)
.add_type_rel("Identity", IdentityRel);
.add_type_rel("Softmax", SoftmaxRel);

// relay.fast_softmax
TVM_REGISTER_NODE_TYPE(SoftmaxAttrs);
Expand All @@ -447,7 +468,7 @@ RELAY_REGISTER_OP("nn.fast_softmax")
.set_num_inputs(1)
.add_argument("data", "Tensor", "The input tensor.")
.set_support_level(1)
.add_type_rel("Identity", IdentityRel);
.add_type_rel("Softmax", SoftmaxRel);

// relay.nn.log_softmax
TVM_REGISTER_GLOBAL("relay.op.nn._make.log_softmax").set_body_typed([](Expr data, int axis) {
Expand All @@ -471,7 +492,7 @@ RELAY_REGISTER_OP("nn.log_softmax")
.set_num_inputs(1)
.add_argument("data", "Tensor", "The input tensor.")
.set_support_level(1)
.add_type_rel("Identity", IdentityRel)
.add_type_rel("Softmax", SoftmaxRel)
.set_attr<FTVMCompute>("FTVMCompute", [](const Attrs& attrs, const Array<te::Tensor>& inputs,
const Type& out_type) {
const auto* param = attrs.as<SoftmaxAttrs>();
Expand Down

0 comments on commit 6732a9e

Please sign in to comment.