From 017d184826ee738e67306c68f8c93e29ca830c89 Mon Sep 17 00:00:00 2001 From: mbs-octoml Date: Thu, 5 May 2022 14:47:05 -0700 Subject: [PATCH 1/7] [Relay] Support i16, f16 scalars in Relay text While testing fp16 models for Collage discovered the Relay text format did not support f16. While adding that cleaned up scalar handling in general. However I left two inlined tests for 'is simple const' in place (fuse_ops.cc and memory_alloc.cc) since it's not clear whether they should remain specific to just {i,f}{32,64} or whether they can be replaced with the support::IsSimpleScalar central predicate. --- src/parser/parser.cc | 45 +---- src/parser/tokenizer.h | 105 +++++++---- src/printer/doc.cc | 3 - src/printer/relay_text_printer.cc | 63 ++----- src/printer/text_printer.h | 7 - src/support/scalars.cc | 209 +++++++++++++++++++++ src/support/scalars.h | 63 +++++++ tests/python/relay/test_ir_parser.py | 33 +++- tests/python/relay/test_ir_text_printer.py | 36 ++-- 9 files changed, 420 insertions(+), 144 deletions(-) create mode 100644 src/support/scalars.cc create mode 100644 src/support/scalars.h diff --git a/src/parser/parser.cc b/src/parser/parser.cc index 9b15893092f7..f51e3e5c9737 100644 --- a/src/parser/parser.cc +++ b/src/parser/parser.cc @@ -35,10 +35,12 @@ #include +#include "../support/scalars.h" #include "./meta_ref.h" #include "./op_table.h" #include "./span_check.h" #include "./tokenizer.h" +#include "tvm/runtime/builtin_fp16.h" namespace tvm { namespace parser { @@ -534,49 +536,15 @@ class Parser { /*! \brief Convert a numeric token to an NDArray for embedding into the Relay program. */ NDArray NumberToNDArray(const Token& token) { if (token->token_type == TokenType::kInteger) { - DLDevice dev = {DLDeviceType::kDLCPU, 0}; - int64_t i = Downcast(token->data); - if (i > std::numeric_limits::max()) { - auto dtype = String2DLDataType("int64"); - auto data = NDArray::Empty({}, dtype, dev); - auto array = reinterpret_cast(data->data); - // revisit this, literal node issue. - array[0] = i; - return data; - } else { - auto dtype = String2DLDataType("int32"); - auto data = NDArray::Empty({}, dtype, dev); - auto array = reinterpret_cast(data->data); - // revisit this, literal node issue. - array[0] = i; - return data; - } + return support::IntImmToNDArray(Downcast(token->data)); } else if (token->token_type == TokenType::kFloat) { - DLDevice dev = {DLDeviceType::kDLCPU, 0}; - auto float_imm = Downcast(token->data); - auto data = NDArray::Empty({}, float_imm->dtype, dev); - auto array = reinterpret_cast(data->data); - // revisit this, literal node issue. - // TODO(@jroesch): bounds checking - float value = float_imm->value; - array[0] = value; - return data; + return support::FloatImmToNDArray(Downcast(token->data)); } else { LOG(FATAL) << "internal error: should only call this function on numeric tokens"; - return NDArray(); + return {}; } } - /*! \brief Convert a boolean value to an NDArray for embedding into the Relay program. */ - NDArray BooleanToNDarray(bool value) { - DLDevice dev = {DLDeviceType::kDLCPU, 0}; - auto dtype = String2DLDataType("bool"); - auto data = NDArray::Empty({}, dtype, dev); - auto array = reinterpret_cast(data->data); - array[0] = value; - return data; - } - [[noreturn]] void ParseError(const Token& token, const std::string& msg) { throw std::runtime_error(msg); } @@ -1573,8 +1541,7 @@ class Parser { case TokenType::kBoolean: { Consume(TokenType::kBoolean); int64_t value = Downcast(next->data); - auto boolean = BooleanToNDarray(value); - Expr e = Constant(boolean, next->span); + Expr e = Constant(support::BoolToNDArray(value), next->span); ICHECK(e->span.defined()) << "constant spans must be defined"; return e; } diff --git a/src/parser/tokenizer.h b/src/parser/tokenizer.h index f8098cf94100..a7d022a41907 100644 --- a/src/parser/tokenizer.h +++ b/src/parser/tokenizer.h @@ -34,6 +34,7 @@ #include #include +#include "../support/scalars.h" #include "./meta_ref.h" #include "./token.h" @@ -172,37 +173,18 @@ struct Tokenizer { } Token ParseNumber(bool is_pos, bool is_float, std::string number) { - ICHECK(number.size() > 0) << "an empty string is an invalid number"; + ICHECK(number.size() > 0) << "an empty string is an invalid float"; - if (!is_float) { - auto token = NewToken(TokenType::kInteger); - size_t index = 0; - int64_t value = 0; - try { - value = std::stoll(number, &index); - } catch (const std::invalid_argument& err) { - this->diag_ctx.Emit(Diagnostic::Error(token->span) << "invalid number `" << number << "`"); - } catch (const std::out_of_range& err) { - this->diag_ctx.Emit(Diagnostic::Error(token->span) << "invalid number `" << number << "`"); - } - if (number.size() <= index) { - value = is_pos ? value : -value; - if (value > std::numeric_limits::max()) { - token->data = tvm::IntImm(DataType::Int(64), value); - } else { - token->data = tvm::IntImm(DataType::Int(32), value); - } - return token; - } + Token token = NewToken(is_float ? TokenType::kFloat : TokenType::kInteger); + size_t suffix_pos = number.rfind(is_float ? 'f' : 'i'); + if (suffix_pos == std::string::npos) { + suffix_pos = number.size(); + } + std::string literal_text = number.substr(0, suffix_pos); + std::string suffix; + if (suffix_pos < number.size()) { + suffix = number.substr(suffix_pos + 1, number.size() - suffix_pos); } - auto token = NewToken(TokenType::kFloat); - - auto suffix_pos = number.rfind("f"); - - auto literal_text = number.substr(0, suffix_pos); - - auto suffix = number.substr(suffix_pos + 1, number.size() - suffix_pos); - int width = 32; if (suffix.size()) { @@ -217,9 +199,61 @@ struct Tokenizer { } } - double value = stod(literal_text); - value = is_pos ? value : -value; - token->data = tvm::FloatImm(DataType::Float(width), value); + if (is_float) { + double value = 0.0; + size_t index = 0; + try { + value = stod(literal_text, &index); + } catch (const std::invalid_argument& err) { + this->diag_ctx.Emit(Diagnostic::Error(token->span) + << "invalid floating point number `" << literal_text << "`"); + } catch (const std::out_of_range& err) { + this->diag_ctx.Emit(Diagnostic::Error(token->span) + << "invalid floating point number `" << literal_text << "`"); + } + if (index < literal_text.size()) { + this->diag_ctx.Emit(Diagnostic::Error(token->span) + << "invalid floating point number `" << literal_text << "`"); + } + value = is_pos ? value : -value; + bool clipped; + std::tie(token->data, clipped) = support::ValueToFloatImm(value, width); + if (clipped) { + this->diag_ctx.Emit(Diagnostic::Error(token->span) + << "floating point number `" << literal_text + << "` out of range for width " << width); + } + } else { + int64_t value = 0; + size_t index = 0; + try { + value = std::stoll(literal_text, &index); + } catch (const std::invalid_argument& err) { + this->diag_ctx.Emit(Diagnostic::Error(token->span) + << "invalid integer number `" << literal_text << "`"); + } catch (const std::out_of_range& err) { + this->diag_ctx.Emit(Diagnostic::Error(token->span) + << "invalid integer number `" << literal_text << "`"); + } + if (index < literal_text.size()) { + this->diag_ctx.Emit(Diagnostic::Error(token->span) + << "invalid integer number `" << literal_text << "`"); + } + value = is_pos ? value : -value; + bool clipped; + std::tie(token->data, clipped) = support::ValueToIntImm(value, width); + if (clipped && suffix.empty()) { + // Without any i suffix the legacy behavior was to default to int64 if out of range + // for int32. + width = 64; + std::tie(token->data, clipped) = support::ValueToIntImm(value, width); + } + if (clipped) { + this->diag_ctx.Emit(Diagnostic::Error(token->span) << "integer number `" << literal_text + << "` out of range for width " << width); + } + } + return token; } @@ -230,14 +264,13 @@ struct Tokenizer { } bool is_float = false; - - // Remove trailing floating point prefix. - if (More() && Peek() == 'f') { + if (More() && (Peek() == 'f' || Peek() == 'i')) { + is_float = Peek() == 'f'; + // Capture trailing width suffix ss << Next(); while (More() && IsNumeric(Peek())) { ss << Next(); } - is_float = true; } return ParseNumber(is_pos, is_float, ss.str()); } diff --git a/src/printer/doc.cc b/src/printer/doc.cc index f7d9fdfd7dfb..10977d083b56 100644 --- a/src/printer/doc.cc +++ b/src/printer/doc.cc @@ -53,9 +53,6 @@ TVM_REGISTER_OBJECT_TYPE(DocTextNode); class DocText : public DocAtom { public: explicit DocText(std::string str) { - if (str.find_first_of("\t\n") != str.npos) { - LOG(WARNING) << "text node: '" << str << "' should not have tab or newline."; - } data_ = runtime::make_object(str); } diff --git a/src/printer/relay_text_printer.cc b/src/printer/relay_text_printer.cc index 97231931ad88..bc66a22a7026 100644 --- a/src/printer/relay_text_printer.cc +++ b/src/printer/relay_text_printer.cc @@ -43,9 +43,11 @@ #include "../ir/attr_functor.h" #include "../parser/meta_ref.h" #include "../relay/analysis/dependency_graph.h" +#include "../support/scalars.h" #include "doc.h" #include "meta_data.h" #include "text_printer.h" +#include "tvm/runtime/builtin_fp16.h" namespace tvm { namespace relay { @@ -61,8 +63,17 @@ Doc RelayTextPrinter::PrintOptionalInfo(const Expr& expr) { } // default annotations if (annotate_ == nullptr) { - if ((expr.as() || expr.as()) && expr->checked_type_.defined()) { - doc << " /* ty=" << Print(expr->checked_type()) << " */"; + if ((expr.as() || expr.as() || expr.as() || + expr.as() || expr.as() || expr.as()) && + (expr->checked_type_.defined() || expr->span.defined())) { + doc << " /*"; + if (expr->checked_type_.defined()) { + doc << " ty=" << Print(expr->checked_type()); + } + if (expr->span.defined()) { + doc << " span=" << PrintSpan(expr->span); + } + doc << " */"; } } else { std::string annotated_expr = annotate_(expr); @@ -219,7 +230,7 @@ Doc RelayTextPrinter::AllocVar(const Var& var) { name = "v" + name; } Doc val = GetUniqueName("%" + name); - memo_[var] = val; + memo_[var] = val; // Referential occurrences will not include the following. if (!var->virtual_device()->IsFullyUnconstrained()) { val << " {" << kVirtualDevice << "=" << PrintAttributeValue(var->virtual_device()) << "}"; } @@ -335,45 +346,10 @@ Doc RelayTextPrinter::PrintExpr(const Expr& expr, bool meta, bool try_inline, bo // first time. Doc RelayTextPrinter::VisitExpr_(const VarNode* op) { return AllocVar(GetRef(op)); } -/*! - * \brief special method to print out const scalar - * \param dtype The data type - * \param value The value to be printed. - */ -template -Doc RelayTextPrinter::ScalarLiteral(DataType dtype, const T& value) { - std::ostringstream os; - if (dtype == DataType::Int(32)) { - os << value; - } else if (dtype == DataType::Float(32)) { - os << value << 'f'; - } else if (dtype == DataType::Float(64)) { - os << value << "f64"; - } else if (dtype == DataType::Bool()) { - return Doc::PyBoolLiteral(value != 0); - } else { - os << value; - } - return Doc::Text(os.str()); -} - Doc RelayTextPrinter::VisitExpr_(const ConstantNode* op) { // Print out simple scalars directly. if (op->is_scalar()) { - std::ostringstream os; - DataType dtype = DataType(op->data->dtype); - ICHECK_EQ(op->data->device.device_type, kDLCPU); - if (dtype == DataType::Int(32)) { - return ScalarLiteral(dtype, static_cast(op->data->data)[0]); - } else if (dtype == DataType::Int(64)) { - return ScalarLiteral(dtype, static_cast(op->data->data)[0]); - } else if (dtype == DataType::Float(32)) { - return ScalarLiteral(dtype, static_cast(op->data->data)[0]); - } else if (dtype == DataType::Float(64)) { - return ScalarLiteral(dtype, static_cast(op->data->data)[0]); - } else if (dtype == DataType::Bool()) { - return ScalarLiteral(dtype, static_cast(op->data->data)[0]); - } + return Doc::Text(support::NDArrayScalarToString(op->data)); } // default fall-back, record it as meta node. Doc doc; @@ -540,9 +516,6 @@ Doc RelayTextPrinter::VisitExpr_(const CallNode* op) { return doc; } else { doc << "(" << Doc::Concat(args) << ")"; - if (op->span.defined()) { - doc << " /* " << PrintSpan(op->span) << " */"; - } return doc; } } @@ -799,11 +772,11 @@ Doc RelayTextPrinter::VisitAttr_(const ArrayNode* op) { } Doc RelayTextPrinter::VisitAttr_(const tir::IntImmNode* op) { - return ScalarLiteral(op->dtype, op->value); + return Doc::Text(support::IntImmToString(GetRef(op))); } Doc RelayTextPrinter::VisitAttr_(const tir::FloatImmNode* op) { - return ScalarLiteral(op->dtype, op->value); + return Doc::Text(support::FloatImmToString(GetRef(op))); } Doc RelayTextPrinter::VisitAttr_(const tir::StringImmNode* op) { @@ -977,7 +950,7 @@ Doc RelayTextPrinter::PrintSpan(const Span& span) { Doc doc; const auto* span_node = span.as(); ICHECK(span_node); - doc << span_node->source_name->name; + doc << span_node->source_name->name << ":" << span_node->line << ":" << span_node->column; return doc; } diff --git a/src/printer/text_printer.h b/src/printer/text_printer.h index c34c4a5b6dbe..05a00e3305e1 100644 --- a/src/printer/text_printer.h +++ b/src/printer/text_printer.h @@ -152,13 +152,6 @@ class RelayTextPrinter : public ExprFunctor, // Should only be triggered when op is a free variable being visited for the // first time. Doc VisitExpr_(const VarNode* op) final; - /*! - * \brief special method to print out const scalar - * \param dtype The data type - * \param value The value to be printed. - */ - template - static Doc ScalarLiteral(DataType dtype, const T& value); Doc VisitExpr_(const ConstantNode* op) final; Doc VisitExpr_(const TupleNode* op) final; Doc VisitExpr_(const TupleGetItemNode* op) final; diff --git a/src/support/scalars.cc b/src/support/scalars.cc new file mode 100644 index 000000000000..10a530c76b41 --- /dev/null +++ b/src/support/scalars.cc @@ -0,0 +1,209 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/support/scalars.cc + * \brief Helpers for converting between scalars in native, text, TIR immediate and NDArray forms. + */ + +#include "./scalars.h" + +#include "tvm/relay/expr.h" +#include "tvm/runtime/builtin_fp16.h" + +namespace tvm { +namespace support { + +static const DataType kInt16 = DataType::Int(16); +static const DataType kInt32 = DataType::Int(32); +static const DataType kInt64 = DataType::Int(64); +static const DataType kFloat16 = DataType::Float(16); +static const DataType kFloat32 = DataType::Float(32); +static const DataType kFloat64 = DataType::Float(64); +static const DataType kBool = DataType::Bool(); + +bool IsSimpleScalar(const relay::ConstantNode* constant_node) { + if (!constant_node->is_scalar()) { + return false; + } + DataType dtype(constant_node->data->dtype); + return dtype == kInt16 || dtype == kInt32 || dtype == kInt64 || dtype == kFloat16 || + dtype == kFloat32 || dtype == kFloat64 || dtype == kBool; +} + +runtime::NDArray IntImmToNDArray(const IntImm& int_imm) { + DLDevice dev = {DLDeviceType::kDLCPU, 0}; + auto data = runtime::NDArray::Empty({}, int_imm->dtype, dev); + if (int_imm.dtype() == kInt64) { + auto array = reinterpret_cast(data->data); + array[0] = int_imm->value; + } else if (int_imm.dtype() == kInt32) { + auto array = reinterpret_cast(data->data); + array[0] = static_cast(int_imm->value); + } else if (int_imm.dtype() == kInt16) { + auto array = reinterpret_cast(data->data); + array[0] = static_cast(int_imm->value); + } else { + LOG(FATAL) << "Unrecognized numeric literal dtype: " << DLDataType2String(int_imm.dtype()); + } + return data; +} + +runtime::NDArray FloatImmToNDArray(const FloatImm& float_imm) { + DLDevice dev = {DLDeviceType::kDLCPU, 0}; + auto data = runtime::NDArray::Empty({}, float_imm->dtype, dev); + if (float_imm.dtype() == kFloat64) { + auto array = reinterpret_cast(data->data); + array[0] = float_imm->value; + } else if (float_imm.dtype() == kFloat32) { + auto array = reinterpret_cast(data->data); + array[0] = static_cast(float_imm->value); + } else if (float_imm.dtype() == kFloat16) { + auto array = reinterpret_cast(data->data); + array[0] = __gnu_f2h_ieee(static_cast(float_imm->value)); + } else { + LOG(FATAL) << "Unrecognized numeric literal dtype: " << DLDataType2String(float_imm.dtype()); + } + return data; +} + +runtime::NDArray BoolToNDArray(bool value) { + DLDevice dev = {DLDeviceType::kDLCPU, 0}; + auto data = runtime::NDArray::Empty({}, kBool, dev); + auto array = reinterpret_cast(data->data); + array[0] = value; + return data; +} + +std::string NDArrayScalarToString(const runtime::NDArray& data) { + std::ostringstream os; + DataType dtype(data->dtype); + ICHECK_EQ(data->device.device_type, kDLCPU) << "Scalars must reside on the CPU to be printed"; + if (dtype == kInt16) { + auto value = static_cast(data->data)[0]; + os << value << "i16"; + } else if (dtype == kInt32) { + auto value = static_cast(data->data)[0]; + os << value; + } else if (dtype == kInt64) { + auto value = static_cast(data->data)[0]; + os << value << "i64"; + } else if (dtype == kFloat16) { + auto value = __gnu_h2f_ieee(static_cast(data->data)[0]); + os << value << "f16"; + } else if (dtype == kFloat32) { + auto value = static_cast(data->data)[0]; + os << value << "f"; + } else if (dtype == kFloat64) { + auto value = static_cast(data->data)[0]; + os << value << "f64"; + } else if (dtype == kBool) { + auto value = static_cast(data->data)[0]; + os << (value ? "True" : "False"); + } else { + LOG(FATAL) << "Unrecognized NDArray scalar dtype: " << DLDataType2String(dtype); + } + return os.str(); +} + +std::string IntImmToString(const IntImm& int_imm) { + std::ostringstream os; + if (int_imm->dtype == kInt16) { + os << int_imm->value << "i16"; + } else if (int_imm->dtype == kInt32) { + os << int_imm->value; + } else if (int_imm->dtype == kInt64) { + os << int_imm->value << "i64"; + } else { + LOG(FATAL) << "Unrecognised IntImm dtype: " << DLDataType2String(int_imm->dtype); + } + return os.str(); +} + +std::string FloatImmToString(const FloatImm& float_imm) { + std::ostringstream os; + if (float_imm->dtype == kFloat16) { + os << float_imm->value << "f16"; + } else if (float_imm->dtype == kFloat32) { + os << float_imm->value << "f"; + } else if (float_imm->dtype == kFloat64) { + os << float_imm->value << "f64"; + } else { + LOG(FATAL) << "Unrecognised FloatImm dtype: " << DLDataType2String(float_imm->dtype); + } + return os.str(); +} + +std::string BoolToString(bool value) { return value ? "True" : "False"; } + +std::pair ValueToIntImm(int64_t value, int width) { + bool clipped = false; + if (width == 16) { + if (value < std::numeric_limits::min()) { + value = std::numeric_limits::min(); + clipped = true; + } + if (value > std::numeric_limits::max()) { + value = std::numeric_limits::max(); + clipped = true; + } + return {IntImm(kInt16, value), clipped}; + } else if (width == 32) { + if (value < std::numeric_limits::min()) { + value = std::numeric_limits::min(); + clipped = true; + } + if (value > std::numeric_limits::max()) { + value = std::numeric_limits::max(); + clipped = true; + } + return {IntImm(kInt32, value), clipped}; + } else if (width == 64) { + return {IntImm(kInt64, value), clipped}; + } else { + LOG(FATAL) << "Unrecognized int scalar width: " << width; + return {}; + } +} + +std::pair ValueToFloatImm(double value, int width) { + bool clipped = false; + if (width == 16) { + // TODO(mbs): Limits for fp16? + return {FloatImm(kFloat16, value), clipped}; + } else if (width == 32) { + if (!std::isinf(value) && value < -std::numeric_limits::max()) { + value = -std::numeric_limits::max(); + clipped = true; + } + if (!std::isinf(value) && value > std::numeric_limits::max()) { + value = std::numeric_limits::max(); + clipped = true; + } + return {FloatImm(kFloat32, value), clipped}; + } else if (width == 64) { + return {FloatImm(kFloat64, value), clipped}; + } else { + LOG(FATAL) << "Unrecognized float scalar width: " << width; + return {}; + } +} + +} // namespace support +} // namespace tvm \ No newline at end of file diff --git a/src/support/scalars.h b/src/support/scalars.h new file mode 100644 index 000000000000..e6f54836c2f8 --- /dev/null +++ b/src/support/scalars.h @@ -0,0 +1,63 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/support/scalars.h + * \brief Helpers for converting between scalars in native, text, TIR immediate and NDArray forms. + */ + +#ifndef TVM_SUPPORT_SCALARS_H_ +#define TVM_SUPPORT_SCALARS_H_ + +#include "tvm/ir/expr.h" +#include "tvm/relay/expr.h" +#include "tvm/runtime/ndarray.h" + +namespace tvm { +namespace support { + +/*! \brief Returns true if \p constant_node is a float/int/bool scalar. */ +bool IsSimpleScalar(const relay::ConstantNode* constant_node); + +/*! \brief Returns NDArray 'scalar' for given TIR immediate. */ +runtime::NDArray IntImmToNDArray(const IntImm& int_imm); +runtime::NDArray FloatImmToNDArray(const FloatImm& float_imm); +runtime::NDArray BoolToNDArray(bool value); + +/*! \brief Returns Relay literal text for NDArray 'scalar'. */ +std::string NDArrayScalarToString(const runtime::NDArray& data); + +/*! \brief Returns Relay literal text for given TIR immediate. */ +std::string IntImmToString(const IntImm& int_imm); +std::string FloatImmToString(const FloatImm& float_imm); +std::string BoolToString(bool value); + +/*! + * \brief Returns TIR immediate for given value and width. Boolean will be true if value + * was clipped in order to stay within range for width. However: + * - we ignore underflow + * - we don't currently check for float16 limits. + */ +std::pair ValueToIntImm(int64_t value, int width); +std::pair ValueToFloatImm(double value, int width); + +} // namespace support +} // namespace tvm + +#endif // TVM_SUPPORT_SCALARS_H_ \ No newline at end of file diff --git a/tests/python/relay/test_ir_parser.py b/tests/python/relay/test_ir_parser.py index fdbd3924ffb7..3bf5ef2a8110 100644 --- a/tests/python/relay/test_ir_parser.py +++ b/tests/python/relay/test_ir_parser.py @@ -15,11 +15,11 @@ # specific language governing permissions and limitations # under the License. import numpy as np +import pytest import tvm from tvm import relay import tvm.relay.testing -import pytest from numpy import isclose from typing import Union @@ -172,6 +172,22 @@ def test_int_literal(): assert get_scalar(parse_text("-05")) == -5 assert get_scalar(parse_text("9223372036854775807")) == 9223372036854775807 + assert get_scalar(parse_text("-42i")) == -42 + assert get_scalar(parse_text("-42i16")) == -42 + assert get_scalar(parse_text("-42i32")) == -42 + assert get_scalar(parse_text("-42i64")) == -42 + + assert_parses_as("-42i16", relay.const(-42, "int16")) + assert_parses_as("-42i32", relay.const(-42, "int32")) + assert_parses_as("-42i", relay.const(-42, "int32")) + assert_parses_as("-42", relay.const(-42, "int32")) + assert_parses_as("-42i64", relay.const(-42, "int64")) + assert_parses_as("2147483647", relay.const(2147483647, "int32")) + assert_parses_as("2147483648", relay.const(2147483648, "int64")) + + with pytest.raises(tvm.error.DiagnosticError): + parse_text("2147483648i32") + parse_text("32768i16") def test_float_literal(): assert get_scalar(parse_text("1.0f")) == 1.0 @@ -189,11 +205,24 @@ def test_float_literal(): assert isclose(get_scalar(parse_text("1.0E-1f")), 1.0e-1) assert get_scalar(parse_text("1.0E+1f")) == 1.0e1 + assert get_scalar(parse_text("3f16")) == 3.0 + assert get_scalar(parse_text("3f32")) == 3.0 + + assert_parses_as("3f16", relay.const(3.0, "float16")) + assert_parses_as("3f32", relay.const(3.0, "float32")) + assert_parses_as("3f", relay.const(3.0, "float32")) + assert_parses_as("3f64", relay.const(3.0, "float64")) + + with pytest.raises(tvm.error.DiagnosticError): + parse_text("3.40283e+38f32") + def test_bool_literal(): assert get_scalar(parse_text("True")) == True assert get_scalar(parse_text("False")) == False + assert_parses_as("True", relay.const(True, "bool")) + def test_negative(): # need to handle parsing non-literal operations @@ -993,4 +1022,4 @@ def @main(%x: Tensor[(2, 3), float32]) { if __name__ == "__main__": import sys - pytest.main(sys.argv) + sys.exit(pytest.main([__file__] + sys.argv[1:])) diff --git a/tests/python/relay/test_ir_text_printer.py b/tests/python/relay/test_ir_text_printer.py index 54e0e4c7ca44..01d32676254f 100644 --- a/tests/python/relay/test_ir_text_printer.py +++ b/tests/python/relay/test_ir_text_printer.py @@ -46,17 +46,26 @@ def show(text): print("---------------------------") print(text) - -# Commented due to weird memory allocation error -# def test_large_graph(): -# x = relay.var("x", shape=(3, 2)) -# y = relay.var("y") -# one = relay.const(10e10, dtype="float32") -# z = relay.add(x, one) -# for i in range(int(9e5)): -# z = relay.add(z, one) -# f = relay.Function([x, y], z) -# show(astext(f)) +def assert_prints_as(expr, str): + assert astext(expr) == SEMVER + str + +def test_scalars(): + assert_prints_as(relay.const(42, "int16"), "42i16") + assert_prints_as(relay.const(42, "int32"), "42") + assert_prints_as(relay.const(42, "int64"), "42i64") + assert_prints_as(relay.const(3.0, "float16"), "3f16") + assert_prints_as(relay.const(3.0, "float32"), "3f") + assert_prints_as(relay.const(3.0, "float64"), "3f64") + +def test_large_graph(): + x = relay.var("x", shape=(3, 2)) + y = relay.var("y") + one = relay.const(10e10, dtype="float32") + z = relay.add(x, one) + for i in range(int(9e4)): + z = relay.add(z, one) + f = relay.Function([x, y], z) + show(astext(f)) def test_func(): @@ -295,4 +304,7 @@ def test_slash_in_identifier(): if __name__ == "__main__": - pytest.main([__file__]) + import sys + import pytest + + sys.exit(pytest.main([__file__] + sys.argv[1:])) From fc3cf0b56918f79f66e59749f1d52abb5ab3a29e Mon Sep 17 00:00:00 2001 From: mbs-octoml Date: Thu, 5 May 2022 14:58:03 -0700 Subject: [PATCH 2/7] - lints --- src/printer/doc.cc | 4 +--- src/support/scalars.cc | 2 +- src/support/scalars.h | 5 ++++- tests/python/relay/test_ir_parser.py | 1 + tests/python/relay/test_ir_text_printer.py | 3 +++ 5 files changed, 10 insertions(+), 5 deletions(-) diff --git a/src/printer/doc.cc b/src/printer/doc.cc index 10977d083b56..b06995fb1286 100644 --- a/src/printer/doc.cc +++ b/src/printer/doc.cc @@ -52,9 +52,7 @@ TVM_REGISTER_OBJECT_TYPE(DocTextNode); class DocText : public DocAtom { public: - explicit DocText(std::string str) { - data_ = runtime::make_object(str); - } + explicit DocText(std::string str) { data_ = runtime::make_object(str); } TVM_DEFINE_OBJECT_REF_METHODS(DocText, DocAtom, DocTextNode); }; diff --git a/src/support/scalars.cc b/src/support/scalars.cc index 10a530c76b41..56dd9d394a3b 100644 --- a/src/support/scalars.cc +++ b/src/support/scalars.cc @@ -206,4 +206,4 @@ std::pair ValueToFloatImm(double value, int width) { } } // namespace support -} // namespace tvm \ No newline at end of file +} // namespace tvm diff --git a/src/support/scalars.h b/src/support/scalars.h index e6f54836c2f8..66924cd0a551 100644 --- a/src/support/scalars.h +++ b/src/support/scalars.h @@ -25,6 +25,9 @@ #ifndef TVM_SUPPORT_SCALARS_H_ #define TVM_SUPPORT_SCALARS_H_ +#include +#include + #include "tvm/ir/expr.h" #include "tvm/relay/expr.h" #include "tvm/runtime/ndarray.h" @@ -60,4 +63,4 @@ std::pair ValueToFloatImm(double value, int width); } // namespace support } // namespace tvm -#endif // TVM_SUPPORT_SCALARS_H_ \ No newline at end of file +#endif // TVM_SUPPORT_SCALARS_H_ diff --git a/tests/python/relay/test_ir_parser.py b/tests/python/relay/test_ir_parser.py index 3bf5ef2a8110..4a5f7632f6c9 100644 --- a/tests/python/relay/test_ir_parser.py +++ b/tests/python/relay/test_ir_parser.py @@ -189,6 +189,7 @@ def test_int_literal(): parse_text("2147483648i32") parse_text("32768i16") + def test_float_literal(): assert get_scalar(parse_text("1.0f")) == 1.0 assert isclose(get_scalar(parse_text("1.56667f")), 1.56667) diff --git a/tests/python/relay/test_ir_text_printer.py b/tests/python/relay/test_ir_text_printer.py index 01d32676254f..60f611998649 100644 --- a/tests/python/relay/test_ir_text_printer.py +++ b/tests/python/relay/test_ir_text_printer.py @@ -46,9 +46,11 @@ def show(text): print("---------------------------") print(text) + def assert_prints_as(expr, str): assert astext(expr) == SEMVER + str + def test_scalars(): assert_prints_as(relay.const(42, "int16"), "42i16") assert_prints_as(relay.const(42, "int32"), "42") @@ -57,6 +59,7 @@ def test_scalars(): assert_prints_as(relay.const(3.0, "float32"), "3f") assert_prints_as(relay.const(3.0, "float64"), "3f64") + def test_large_graph(): x = relay.var("x", shape=(3, 2)) y = relay.var("y") From 322a0cc490fb55871c9a726a0252105170bd7656 Mon Sep 17 00:00:00 2001 From: mbs-octoml Date: Thu, 5 May 2022 15:05:58 -0700 Subject: [PATCH 3/7] - fix message --- src/parser/tokenizer.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/parser/tokenizer.h b/src/parser/tokenizer.h index a7d022a41907..6d4559277e60 100644 --- a/src/parser/tokenizer.h +++ b/src/parser/tokenizer.h @@ -173,7 +173,7 @@ struct Tokenizer { } Token ParseNumber(bool is_pos, bool is_float, std::string number) { - ICHECK(number.size() > 0) << "an empty string is an invalid float"; + ICHECK(number.size() > 0) << "an empty string is an invalid number"; Token token = NewToken(is_float ? TokenType::kFloat : TokenType::kInteger); size_t suffix_pos = number.rfind(is_float ? 'f' : 'i'); From 76e61c545225687a55e3612cd4143d99afe3693b Mon Sep 17 00:00:00 2001 From: mbs-octoml Date: Fri, 6 May 2022 13:08:54 -0700 Subject: [PATCH 4/7] - preserve legacy print behavior for unrecognised dtypes. --- src/printer/relay_text_printer.cc | 21 +++++++++++---- src/support/scalars.cc | 34 ++++++++++++------------- src/support/scalars.h | 3 +++ tests/python/relay/test_target_hooks.py | 3 ++- 4 files changed, 38 insertions(+), 23 deletions(-) diff --git a/src/printer/relay_text_printer.cc b/src/printer/relay_text_printer.cc index bc66a22a7026..35daf588fbeb 100644 --- a/src/printer/relay_text_printer.cc +++ b/src/printer/relay_text_printer.cc @@ -348,14 +348,15 @@ Doc RelayTextPrinter::VisitExpr_(const VarNode* op) { return AllocVar(GetRefis_scalar()) { + if (support::IsSimpleScalar(op)) { return Doc::Text(support::NDArrayScalarToString(op->data)); } - // default fall-back, record it as meta node. + // Fallbock: record it as a meta node. Doc doc; // Don't append optional_info. Because the entry function is Print, // and it will append the optional_info afterwards. - return doc << PrintExpr(GetRef(op), true, false, false); + return doc << PrintExpr(GetRef(op), /*meta=*/true, /*try_inline=*/false, + /*optional_info=*/false); } Doc RelayTextPrinter::VisitExpr_(const TupleNode* op) { @@ -772,11 +773,21 @@ Doc RelayTextPrinter::VisitAttr_(const ArrayNode* op) { } Doc RelayTextPrinter::VisitAttr_(const tir::IntImmNode* op) { - return Doc::Text(support::IntImmToString(GetRef(op))); + if (support::IsSimpleScalarDtype(op->dtype)) { + return Doc::Text(support::IntImmToString(GetRef(op))); + } else { + // Fallback: Print int64_t without width suffix. + return Doc::Text(std::to_string(op->value)); + } } Doc RelayTextPrinter::VisitAttr_(const tir::FloatImmNode* op) { - return Doc::Text(support::FloatImmToString(GetRef(op))); + if (support::IsSimpleScalarDtype(op->dtype)) { + return Doc::Text(support::FloatImmToString(GetRef(op))); + } else { + // Fallbock: Print double without width suffix. + return Doc::Text(std::to_string(op->value)); + } } Doc RelayTextPrinter::VisitAttr_(const tir::StringImmNode* op) { diff --git a/src/support/scalars.cc b/src/support/scalars.cc index 56dd9d394a3b..51ac7c1da06d 100644 --- a/src/support/scalars.cc +++ b/src/support/scalars.cc @@ -38,27 +38,27 @@ static const DataType kFloat32 = DataType::Float(32); static const DataType kFloat64 = DataType::Float(64); static const DataType kBool = DataType::Bool(); -bool IsSimpleScalar(const relay::ConstantNode* constant_node) { - if (!constant_node->is_scalar()) { - return false; - } - DataType dtype(constant_node->data->dtype); +bool IsSimpleScalarDtype(DataType dtype) { return dtype == kInt16 || dtype == kInt32 || dtype == kInt64 || dtype == kFloat16 || dtype == kFloat32 || dtype == kFloat64 || dtype == kBool; } +bool IsSimpleScalar(const relay::ConstantNode* constant_node) { + return constant_node->is_scalar() && IsSimpleScalarDtype(DataType(constant_node->data->dtype)); +} + runtime::NDArray IntImmToNDArray(const IntImm& int_imm) { DLDevice dev = {DLDeviceType::kDLCPU, 0}; auto data = runtime::NDArray::Empty({}, int_imm->dtype, dev); - if (int_imm.dtype() == kInt64) { - auto array = reinterpret_cast(data->data); - array[0] = int_imm->value; + if (int_imm.dtype() == kInt16) { + auto array = reinterpret_cast(data->data); + array[0] = static_cast(int_imm->value); } else if (int_imm.dtype() == kInt32) { auto array = reinterpret_cast(data->data); array[0] = static_cast(int_imm->value); - } else if (int_imm.dtype() == kInt16) { - auto array = reinterpret_cast(data->data); - array[0] = static_cast(int_imm->value); + } else if (int_imm.dtype() == kInt64) { + auto array = reinterpret_cast(data->data); + array[0] = int_imm->value; } else { LOG(FATAL) << "Unrecognized numeric literal dtype: " << DLDataType2String(int_imm.dtype()); } @@ -68,15 +68,15 @@ runtime::NDArray IntImmToNDArray(const IntImm& int_imm) { runtime::NDArray FloatImmToNDArray(const FloatImm& float_imm) { DLDevice dev = {DLDeviceType::kDLCPU, 0}; auto data = runtime::NDArray::Empty({}, float_imm->dtype, dev); - if (float_imm.dtype() == kFloat64) { - auto array = reinterpret_cast(data->data); - array[0] = float_imm->value; + if (float_imm.dtype() == kFloat16) { + auto array = reinterpret_cast(data->data); + array[0] = __gnu_f2h_ieee(static_cast(float_imm->value)); } else if (float_imm.dtype() == kFloat32) { auto array = reinterpret_cast(data->data); array[0] = static_cast(float_imm->value); - } else if (float_imm.dtype() == kFloat16) { - auto array = reinterpret_cast(data->data); - array[0] = __gnu_f2h_ieee(static_cast(float_imm->value)); + } else if (float_imm.dtype() == kFloat64) { + auto array = reinterpret_cast(data->data); + array[0] = float_imm->value; } else { LOG(FATAL) << "Unrecognized numeric literal dtype: " << DLDataType2String(float_imm.dtype()); } diff --git a/src/support/scalars.h b/src/support/scalars.h index 66924cd0a551..b9ccc64f1867 100644 --- a/src/support/scalars.h +++ b/src/support/scalars.h @@ -35,6 +35,9 @@ namespace tvm { namespace support { +/*! \brief Returns true if a tensor of empty shape and given dtype is considered a Relay scalar. */ +bool IsSimpleScalarDtype(DataType dtype); + /*! \brief Returns true if \p constant_node is a float/int/bool scalar. */ bool IsSimpleScalar(const relay::ConstantNode* constant_node); diff --git a/tests/python/relay/test_target_hooks.py b/tests/python/relay/test_target_hooks.py index 5856dc1e1c69..7086cf3cedc7 100644 --- a/tests/python/relay/test_target_hooks.py +++ b/tests/python/relay/test_target_hooks.py @@ -73,4 +73,5 @@ def test_runtime_module_generation(check_result): if __name__ == "__main__": - sys.exit(pytest.main([__file__] + sys.argv[1:])) + # sys.exit(pytest.main([__file__] + sys.argv[1:])) + test_runtime_module_generation(check_aot_executor_result) From 82f3e4a0fe47fe028302b6ff85a02c80ce1fd4f0 Mon Sep 17 00:00:00 2001 From: mbs-octoml Date: Fri, 6 May 2022 15:28:14 -0700 Subject: [PATCH 5/7] - bool is legit IntImm --- src/support/scalars.cc | 4 ++-- src/support/scalars.h | 1 - 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/src/support/scalars.cc b/src/support/scalars.cc index 51ac7c1da06d..81c9093b86f0 100644 --- a/src/support/scalars.cc +++ b/src/support/scalars.cc @@ -130,6 +130,8 @@ std::string IntImmToString(const IntImm& int_imm) { os << int_imm->value; } else if (int_imm->dtype == kInt64) { os << int_imm->value << "i64"; + } else if (int_imm->dtype == kBool) { + os << (int_imm->value ? "True" : "False"); } else { LOG(FATAL) << "Unrecognised IntImm dtype: " << DLDataType2String(int_imm->dtype); } @@ -150,8 +152,6 @@ std::string FloatImmToString(const FloatImm& float_imm) { return os.str(); } -std::string BoolToString(bool value) { return value ? "True" : "False"; } - std::pair ValueToIntImm(int64_t value, int width) { bool clipped = false; if (width == 16) { diff --git a/src/support/scalars.h b/src/support/scalars.h index b9ccc64f1867..c405e5eec61e 100644 --- a/src/support/scalars.h +++ b/src/support/scalars.h @@ -52,7 +52,6 @@ std::string NDArrayScalarToString(const runtime::NDArray& data); /*! \brief Returns Relay literal text for given TIR immediate. */ std::string IntImmToString(const IntImm& int_imm); std::string FloatImmToString(const FloatImm& float_imm); -std::string BoolToString(bool value); /*! * \brief Returns TIR immediate for given value and width. Boolean will be true if value From bc46b550011010e158748d5f0f94cc13ea8f94d7 Mon Sep 17 00:00:00 2001 From: mbs-octoml Date: Wed, 18 May 2022 17:10:22 -0700 Subject: [PATCH 6/7] - Chris' comments --- src/parser/tokenizer.h | 23 ++++---- src/support/scalars.cc | 70 +++++++++++-------------- src/support/scalars.h | 11 ++-- tests/cpp/support/scalars_test.cc | 60 +++++++++++++++++++++ tests/python/relay/test_ir_parser.py | 7 +++ tests/python/relay/test_target_hooks.py | 3 +- 6 files changed, 117 insertions(+), 57 deletions(-) create mode 100644 tests/cpp/support/scalars_test.cc diff --git a/src/parser/tokenizer.h b/src/parser/tokenizer.h index 6d4559277e60..4ac1ceef26dc 100644 --- a/src/parser/tokenizer.h +++ b/src/parser/tokenizer.h @@ -216,12 +216,12 @@ struct Tokenizer { << "invalid floating point number `" << literal_text << "`"); } value = is_pos ? value : -value; - bool clipped; - std::tie(token->data, clipped) = support::ValueToFloatImm(value, width); - if (clipped) { + token->data = support::ValueToFloatImm(value, width); + if (!token->data.defined()) { this->diag_ctx.Emit(Diagnostic::Error(token->span) << "floating point number `" << literal_text - << "` out of range for width " << width); + << "` unrepresentable in width " << width); + token->data = support::ValueToFloatImm(0.0, width); } } else { int64_t value = 0; @@ -240,17 +240,18 @@ struct Tokenizer { << "invalid integer number `" << literal_text << "`"); } value = is_pos ? value : -value; - bool clipped; - std::tie(token->data, clipped) = support::ValueToIntImm(value, width); - if (clipped && suffix.empty()) { + token->data = support::ValueToIntImm(value, width); + if (!token->data.defined() && suffix.empty()) { // Without any i suffix the legacy behavior was to default to int64 if out of range // for int32. width = 64; - std::tie(token->data, clipped) = support::ValueToIntImm(value, width); + token->data = support::ValueToIntImm(value, width); } - if (clipped) { - this->diag_ctx.Emit(Diagnostic::Error(token->span) << "integer number `" << literal_text - << "` out of range for width " << width); + if (!token->data.defined()) { + this->diag_ctx.Emit(Diagnostic::Error(token->span) + << "integer number `" << literal_text << "` unrepresentable in width " + << width); + token->data = support::ValueToIntImm(0, width); } } diff --git a/src/support/scalars.cc b/src/support/scalars.cc index 81c9093b86f0..b64840537548 100644 --- a/src/support/scalars.cc +++ b/src/support/scalars.cc @@ -30,6 +30,7 @@ namespace tvm { namespace support { +/*! \brief The standard scalar dtypes. */ static const DataType kInt16 = DataType::Int(16); static const DataType kInt32 = DataType::Int(32); static const DataType kInt64 = DataType::Int(64); @@ -51,13 +52,13 @@ runtime::NDArray IntImmToNDArray(const IntImm& int_imm) { DLDevice dev = {DLDeviceType::kDLCPU, 0}; auto data = runtime::NDArray::Empty({}, int_imm->dtype, dev); if (int_imm.dtype() == kInt16) { - auto array = reinterpret_cast(data->data); + auto* array = reinterpret_cast(data->data); array[0] = static_cast(int_imm->value); } else if (int_imm.dtype() == kInt32) { - auto array = reinterpret_cast(data->data); + auto* array = reinterpret_cast(data->data); array[0] = static_cast(int_imm->value); } else if (int_imm.dtype() == kInt64) { - auto array = reinterpret_cast(data->data); + auto* array = reinterpret_cast(data->data); array[0] = int_imm->value; } else { LOG(FATAL) << "Unrecognized numeric literal dtype: " << DLDataType2String(int_imm.dtype()); @@ -69,13 +70,13 @@ runtime::NDArray FloatImmToNDArray(const FloatImm& float_imm) { DLDevice dev = {DLDeviceType::kDLCPU, 0}; auto data = runtime::NDArray::Empty({}, float_imm->dtype, dev); if (float_imm.dtype() == kFloat16) { - auto array = reinterpret_cast(data->data); + auto* array = reinterpret_cast(data->data); array[0] = __gnu_f2h_ieee(static_cast(float_imm->value)); } else if (float_imm.dtype() == kFloat32) { - auto array = reinterpret_cast(data->data); + auto* array = reinterpret_cast(data->data); array[0] = static_cast(float_imm->value); } else if (float_imm.dtype() == kFloat64) { - auto array = reinterpret_cast(data->data); + auto* array = reinterpret_cast(data->data); array[0] = float_imm->value; } else { LOG(FATAL) << "Unrecognized numeric literal dtype: " << DLDataType2String(float_imm.dtype()); @@ -152,53 +153,46 @@ std::string FloatImmToString(const FloatImm& float_imm) { return os.str(); } -std::pair ValueToIntImm(int64_t value, int width) { - bool clipped = false; +IntImm ValueToIntImm(int64_t value, int width) { if (width == 16) { - if (value < std::numeric_limits::min()) { - value = std::numeric_limits::min(); - clipped = true; + if (value < std::numeric_limits::min() || + value > std::numeric_limits::max()) { + return {}; } - if (value > std::numeric_limits::max()) { - value = std::numeric_limits::max(); - clipped = true; - } - return {IntImm(kInt16, value), clipped}; + return IntImm(kInt16, value); } else if (width == 32) { - if (value < std::numeric_limits::min()) { - value = std::numeric_limits::min(); - clipped = true; - } - if (value > std::numeric_limits::max()) { - value = std::numeric_limits::max(); - clipped = true; + if (value < std::numeric_limits::min() || + value > std::numeric_limits::max()) { + return {}; } - return {IntImm(kInt32, value), clipped}; + return IntImm(kInt32, value); } else if (width == 64) { - return {IntImm(kInt64, value), clipped}; + return IntImm(kInt64, value); } else { LOG(FATAL) << "Unrecognized int scalar width: " << width; return {}; } } -std::pair ValueToFloatImm(double value, int width) { - bool clipped = false; +// 2^15 * (1 + 1023/1024) +// See https://en.wikipedia.org/wiki/Half-precision_floating-point_format +constexpr double kMaxFloat16 = 65504.0; + +FloatImm ValueToFloatImm(double value, int width) { if (width == 16) { - // TODO(mbs): Limits for fp16? - return {FloatImm(kFloat16, value), clipped}; - } else if (width == 32) { - if (!std::isinf(value) && value < -std::numeric_limits::max()) { - value = -std::numeric_limits::max(); - clipped = true; + if (!std::isinf(value) && + (value < -kMaxFloat16 || value > kMaxFloat16)) { + return {}; } - if (!std::isinf(value) && value > std::numeric_limits::max()) { - value = std::numeric_limits::max(); - clipped = true; + return FloatImm(kFloat16, value); + } else if (width == 32) { + if (!std::isinf(value) && + (value < -std::numeric_limits::max() || value > std::numeric_limits::max())) { + return {}; } - return {FloatImm(kFloat32, value), clipped}; + return FloatImm(kFloat32, value); } else if (width == 64) { - return {FloatImm(kFloat64, value), clipped}; + return FloatImm(kFloat64, value); } else { LOG(FATAL) << "Unrecognized float scalar width: " << width; return {}; diff --git a/src/support/scalars.h b/src/support/scalars.h index c405e5eec61e..60b8fc40a8de 100644 --- a/src/support/scalars.h +++ b/src/support/scalars.h @@ -54,13 +54,12 @@ std::string IntImmToString(const IntImm& int_imm); std::string FloatImmToString(const FloatImm& float_imm); /*! - * \brief Returns TIR immediate for given value and width. Boolean will be true if value - * was clipped in order to stay within range for width. However: - * - we ignore underflow - * - we don't currently check for float16 limits. + * \brief Returns TIR immediate for given value and width. Result will be null if value is + * out of range in width. Note however for floating point we don't check if the value is + * representable without loss of precision. */ -std::pair ValueToIntImm(int64_t value, int width); -std::pair ValueToFloatImm(double value, int width); +IntImm ValueToIntImm(int64_t value, int width); +FloatImm ValueToFloatImm(double value, int width); } // namespace support } // namespace tvm diff --git a/tests/cpp/support/scalars_test.cc b/tests/cpp/support/scalars_test.cc new file mode 100644 index 000000000000..491ba48847c9 --- /dev/null +++ b/tests/cpp/support/scalars_test.cc @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include "../../../src/support/scalars.h" + +#include +#include + +namespace tvm { +namespace support { +namespace { + +TEST(Scalars, IntImmToNDArray_Unsupported) { + ASSERT_THROW(IntImmToNDArray(IntImm(DataType::Int(15), 42)), runtime::InternalError); +} + +TEST(Scalars, FloatImmtoNDArray_Unsupported) { + ASSERT_THROW(FloatImmToNDArray(FloatImm(DataType::Float(15), 42.0)), runtime::InternalError); +} + +TEST(Scalars, NDArrayScalarToString_Unsupported) { + auto ndarray = runtime::NDArray::Empty({}, DataType::Int(8), {DLDeviceType::kDLCPU, 0}); + ASSERT_THROW(NDArrayScalarToString(ndarray), runtime::InternalError); +} + +TEST(Scalars, IntImmToString_Unsupported) { + ASSERT_THROW(IntImmToString(IntImm(DataType::Int(15), 42)), runtime::InternalError); +} + +TEST(Scalars, FloatImmToString_Unsupported) { + ASSERT_THROW(FloatImmToString(FloatImm(DataType::Float(15), 42.0)), runtime::InternalError); +} + +TEST(Scalars, ValueToIntImm_Unsupported) { + ASSERT_THROW(ValueToIntImm(42, 15), runtime::InternalError); +} + +TEST(SCalars, ValueToFloatImm_Unsupported) { + ASSERT_THROW(ValueToFloatImm(42.0, 15), runtime::InternalError); +} + +} // namespace +} // namespace support +} // namespace tvm \ No newline at end of file diff --git a/tests/python/relay/test_ir_parser.py b/tests/python/relay/test_ir_parser.py index 4a5f7632f6c9..7a283461e0bd 100644 --- a/tests/python/relay/test_ir_parser.py +++ b/tests/python/relay/test_ir_parser.py @@ -186,7 +186,10 @@ def test_int_literal(): assert_parses_as("2147483648", relay.const(2147483648, "int64")) with pytest.raises(tvm.error.DiagnosticError): + # Unrepresentable parse_text("2147483648i32") + with pytest.raises(tvm.error.DiagnosticError): + # Unrepresentable parse_text("32768i16") @@ -215,7 +218,11 @@ def test_float_literal(): assert_parses_as("3f64", relay.const(3.0, "float64")) with pytest.raises(tvm.error.DiagnosticError): + # Unrepresentable parse_text("3.40283e+38f32") + with pytest.raises(tvm.error.DiagnosticError): + # Unrepresentable + parse_text("65505f16") def test_bool_literal(): diff --git a/tests/python/relay/test_target_hooks.py b/tests/python/relay/test_target_hooks.py index 7086cf3cedc7..5856dc1e1c69 100644 --- a/tests/python/relay/test_target_hooks.py +++ b/tests/python/relay/test_target_hooks.py @@ -73,5 +73,4 @@ def test_runtime_module_generation(check_result): if __name__ == "__main__": - # sys.exit(pytest.main([__file__] + sys.argv[1:])) - test_runtime_module_generation(check_aot_executor_result) + sys.exit(pytest.main([__file__] + sys.argv[1:])) From 86d34348170734f791868a5ca112d97b671292e5 Mon Sep 17 00:00:00 2001 From: mbs-octoml Date: Wed, 18 May 2022 17:19:33 -0700 Subject: [PATCH 7/7] - Lints --- src/support/scalars.cc | 3 +-- tests/cpp/support/scalars_test.cc | 5 ++++- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/src/support/scalars.cc b/src/support/scalars.cc index b64840537548..9caa7ca58915 100644 --- a/src/support/scalars.cc +++ b/src/support/scalars.cc @@ -180,8 +180,7 @@ constexpr double kMaxFloat16 = 65504.0; FloatImm ValueToFloatImm(double value, int width) { if (width == 16) { - if (!std::isinf(value) && - (value < -kMaxFloat16 || value > kMaxFloat16)) { + if (!std::isinf(value) && (value < -kMaxFloat16 || value > kMaxFloat16)) { return {}; } return FloatImm(kFloat16, value); diff --git a/tests/cpp/support/scalars_test.cc b/tests/cpp/support/scalars_test.cc index 491ba48847c9..d55f0541fa40 100644 --- a/tests/cpp/support/scalars_test.cc +++ b/tests/cpp/support/scalars_test.cc @@ -26,6 +26,9 @@ namespace tvm { namespace support { namespace { +// Note that functional testing is via test_ir_parser.py and test_ir_text_printer.py. +// Here we just check handling which is difficult to test via the standard Python API. + TEST(Scalars, IntImmToNDArray_Unsupported) { ASSERT_THROW(IntImmToNDArray(IntImm(DataType::Int(15), 42)), runtime::InternalError); } @@ -57,4 +60,4 @@ TEST(SCalars, ValueToFloatImm_Unsupported) { } // namespace } // namespace support -} // namespace tvm \ No newline at end of file +} // namespace tvm