Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Unity] Improved error message for matmul shape mismatch #16308

Merged
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 19 additions & 6 deletions src/relax/op/tensor/linear_algebra.cc
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ TVM_REGISTER_GLOBAL("relax.op.matmul").set_body_typed(matmul);

StructInfo InferStructInfoMatmul(const Call& call, const BlockBuilder& ctx) {
Array<TensorStructInfo> input_sinfo = GetInputTensorStructInfo(call, ctx);
Expr lhs = call->args[0];
Expr rhs = call->args[1];
TensorStructInfo x1_sinfo = input_sinfo[0];
TensorStructInfo x2_sinfo = input_sinfo[1];

Expand Down Expand Up @@ -75,10 +77,19 @@ StructInfo InferStructInfoMatmul(const Call& call, const BlockBuilder& ctx) {
}
int x1_ndim = x1_sinfo->ndim;
int x2_ndim = x2_sinfo->ndim;
if (x1_ndim == 0 || x2_ndim == 0) {
if (x1_ndim == 0) {
ctx->ReportFatal(Diagnostic::Error(call)
<< "Matmul requires both inputs to have at least 1 dimension. However, "
<< (x1_ndim == 0 ? "x1" : "x2") << " is a 0-rank tensor.");
<< "Matmul operands must not be scalar. "
<< "However, the expression " << call << " has a LHS of " << lhs
<< " with struct info " << x1_sinfo
<< ", which is scalar (zero-dimensional) tensor.");
}
if (x2_ndim == 0) {
ctx->ReportFatal(Diagnostic::Error(call)
<< "Matmul operands must not be scalar. "
<< "However, the expression " << call << " has a RHS of " << rhs
<< " with struct info " << x2_sinfo
<< ", which is scalar (zero-dimensional) tensor.");
}

int x1_prepended = 0;
Expand Down Expand Up @@ -120,9 +131,11 @@ StructInfo InferStructInfoMatmul(const Call& call, const BlockBuilder& ctx) {
PrimExpr x2_reduction_length = x2_shape->values[x2_ndim - 2];
if (analyzer->CanProve(x1_reduction_length != x2_reduction_length)) {
ctx->ReportFatal(Diagnostic::Error(call)
<< "Matmul requires the reduction length of x1 and x2 to be equal. However, "
"the reduction lengths of x1 and x2 are "
<< x1_reduction_length << " and " << x2_reduction_length << " respectively.");
<< "Matmul requires the reduction length of the operands to be equal. "
<< "However, the LHS " << lhs << " has shape " << x1_sinfo->shape
<< ", while the RHS " << rhs << " has shape " << x2_sinfo->shape
<< ". The reduction dimensions of " << x1_reduction_length << " and "
<< x2_reduction_length << " are not equal.");
}

Array<PrimExpr> output_shape = output_shape_prefix.value();
Expand Down
Loading