diff --git a/src/parser/parser.cc b/src/parser/parser.cc index 9b15893092f7..f51e3e5c9737 100644 --- a/src/parser/parser.cc +++ b/src/parser/parser.cc @@ -35,10 +35,12 @@ #include +#include "../support/scalars.h" #include "./meta_ref.h" #include "./op_table.h" #include "./span_check.h" #include "./tokenizer.h" +#include "tvm/runtime/builtin_fp16.h" namespace tvm { namespace parser { @@ -534,49 +536,15 @@ class Parser { /*! \brief Convert a numeric token to an NDArray for embedding into the Relay program. */ NDArray NumberToNDArray(const Token& token) { if (token->token_type == TokenType::kInteger) { - DLDevice dev = {DLDeviceType::kDLCPU, 0}; - int64_t i = Downcast(token->data); - if (i > std::numeric_limits::max()) { - auto dtype = String2DLDataType("int64"); - auto data = NDArray::Empty({}, dtype, dev); - auto array = reinterpret_cast(data->data); - // revisit this, literal node issue. - array[0] = i; - return data; - } else { - auto dtype = String2DLDataType("int32"); - auto data = NDArray::Empty({}, dtype, dev); - auto array = reinterpret_cast(data->data); - // revisit this, literal node issue. - array[0] = i; - return data; - } + return support::IntImmToNDArray(Downcast(token->data)); } else if (token->token_type == TokenType::kFloat) { - DLDevice dev = {DLDeviceType::kDLCPU, 0}; - auto float_imm = Downcast(token->data); - auto data = NDArray::Empty({}, float_imm->dtype, dev); - auto array = reinterpret_cast(data->data); - // revisit this, literal node issue. - // TODO(@jroesch): bounds checking - float value = float_imm->value; - array[0] = value; - return data; + return support::FloatImmToNDArray(Downcast(token->data)); } else { LOG(FATAL) << "internal error: should only call this function on numeric tokens"; - return NDArray(); + return {}; } } - /*! \brief Convert a boolean value to an NDArray for embedding into the Relay program. */ - NDArray BooleanToNDarray(bool value) { - DLDevice dev = {DLDeviceType::kDLCPU, 0}; - auto dtype = String2DLDataType("bool"); - auto data = NDArray::Empty({}, dtype, dev); - auto array = reinterpret_cast(data->data); - array[0] = value; - return data; - } - [[noreturn]] void ParseError(const Token& token, const std::string& msg) { throw std::runtime_error(msg); } @@ -1573,8 +1541,7 @@ class Parser { case TokenType::kBoolean: { Consume(TokenType::kBoolean); int64_t value = Downcast(next->data); - auto boolean = BooleanToNDarray(value); - Expr e = Constant(boolean, next->span); + Expr e = Constant(support::BoolToNDArray(value), next->span); ICHECK(e->span.defined()) << "constant spans must be defined"; return e; } diff --git a/src/parser/tokenizer.h b/src/parser/tokenizer.h index f8098cf94100..4ac1ceef26dc 100644 --- a/src/parser/tokenizer.h +++ b/src/parser/tokenizer.h @@ -34,6 +34,7 @@ #include #include +#include "../support/scalars.h" #include "./meta_ref.h" #include "./token.h" @@ -174,35 +175,16 @@ struct Tokenizer { Token ParseNumber(bool is_pos, bool is_float, std::string number) { ICHECK(number.size() > 0) << "an empty string is an invalid number"; - if (!is_float) { - auto token = NewToken(TokenType::kInteger); - size_t index = 0; - int64_t value = 0; - try { - value = std::stoll(number, &index); - } catch (const std::invalid_argument& err) { - this->diag_ctx.Emit(Diagnostic::Error(token->span) << "invalid number `" << number << "`"); - } catch (const std::out_of_range& err) { - this->diag_ctx.Emit(Diagnostic::Error(token->span) << "invalid number `" << number << "`"); - } - if (number.size() <= index) { - value = is_pos ? value : -value; - if (value > std::numeric_limits::max()) { - token->data = tvm::IntImm(DataType::Int(64), value); - } else { - token->data = tvm::IntImm(DataType::Int(32), value); - } - return token; - } + Token token = NewToken(is_float ? TokenType::kFloat : TokenType::kInteger); + size_t suffix_pos = number.rfind(is_float ? 'f' : 'i'); + if (suffix_pos == std::string::npos) { + suffix_pos = number.size(); + } + std::string literal_text = number.substr(0, suffix_pos); + std::string suffix; + if (suffix_pos < number.size()) { + suffix = number.substr(suffix_pos + 1, number.size() - suffix_pos); } - auto token = NewToken(TokenType::kFloat); - - auto suffix_pos = number.rfind("f"); - - auto literal_text = number.substr(0, suffix_pos); - - auto suffix = number.substr(suffix_pos + 1, number.size() - suffix_pos); - int width = 32; if (suffix.size()) { @@ -217,9 +199,62 @@ struct Tokenizer { } } - double value = stod(literal_text); - value = is_pos ? value : -value; - token->data = tvm::FloatImm(DataType::Float(width), value); + if (is_float) { + double value = 0.0; + size_t index = 0; + try { + value = stod(literal_text, &index); + } catch (const std::invalid_argument& err) { + this->diag_ctx.Emit(Diagnostic::Error(token->span) + << "invalid floating point number `" << literal_text << "`"); + } catch (const std::out_of_range& err) { + this->diag_ctx.Emit(Diagnostic::Error(token->span) + << "invalid floating point number `" << literal_text << "`"); + } + if (index < literal_text.size()) { + this->diag_ctx.Emit(Diagnostic::Error(token->span) + << "invalid floating point number `" << literal_text << "`"); + } + value = is_pos ? value : -value; + token->data = support::ValueToFloatImm(value, width); + if (!token->data.defined()) { + this->diag_ctx.Emit(Diagnostic::Error(token->span) + << "floating point number `" << literal_text + << "` unrepresentable in width " << width); + token->data = support::ValueToFloatImm(0.0, width); + } + } else { + int64_t value = 0; + size_t index = 0; + try { + value = std::stoll(literal_text, &index); + } catch (const std::invalid_argument& err) { + this->diag_ctx.Emit(Diagnostic::Error(token->span) + << "invalid integer number `" << literal_text << "`"); + } catch (const std::out_of_range& err) { + this->diag_ctx.Emit(Diagnostic::Error(token->span) + << "invalid integer number `" << literal_text << "`"); + } + if (index < literal_text.size()) { + this->diag_ctx.Emit(Diagnostic::Error(token->span) + << "invalid integer number `" << literal_text << "`"); + } + value = is_pos ? value : -value; + token->data = support::ValueToIntImm(value, width); + if (!token->data.defined() && suffix.empty()) { + // Without any i suffix the legacy behavior was to default to int64 if out of range + // for int32. + width = 64; + token->data = support::ValueToIntImm(value, width); + } + if (!token->data.defined()) { + this->diag_ctx.Emit(Diagnostic::Error(token->span) + << "integer number `" << literal_text << "` unrepresentable in width " + << width); + token->data = support::ValueToIntImm(0, width); + } + } + return token; } @@ -230,14 +265,13 @@ struct Tokenizer { } bool is_float = false; - - // Remove trailing floating point prefix. - if (More() && Peek() == 'f') { + if (More() && (Peek() == 'f' || Peek() == 'i')) { + is_float = Peek() == 'f'; + // Capture trailing width suffix ss << Next(); while (More() && IsNumeric(Peek())) { ss << Next(); } - is_float = true; } return ParseNumber(is_pos, is_float, ss.str()); } diff --git a/src/printer/doc.cc b/src/printer/doc.cc index f7d9fdfd7dfb..b06995fb1286 100644 --- a/src/printer/doc.cc +++ b/src/printer/doc.cc @@ -52,12 +52,7 @@ TVM_REGISTER_OBJECT_TYPE(DocTextNode); class DocText : public DocAtom { public: - explicit DocText(std::string str) { - if (str.find_first_of("\t\n") != str.npos) { - LOG(WARNING) << "text node: '" << str << "' should not have tab or newline."; - } - data_ = runtime::make_object(str); - } + explicit DocText(std::string str) { data_ = runtime::make_object(str); } TVM_DEFINE_OBJECT_REF_METHODS(DocText, DocAtom, DocTextNode); }; diff --git a/src/printer/relay_text_printer.cc b/src/printer/relay_text_printer.cc index 97231931ad88..35daf588fbeb 100644 --- a/src/printer/relay_text_printer.cc +++ b/src/printer/relay_text_printer.cc @@ -43,9 +43,11 @@ #include "../ir/attr_functor.h" #include "../parser/meta_ref.h" #include "../relay/analysis/dependency_graph.h" +#include "../support/scalars.h" #include "doc.h" #include "meta_data.h" #include "text_printer.h" +#include "tvm/runtime/builtin_fp16.h" namespace tvm { namespace relay { @@ -61,8 +63,17 @@ Doc RelayTextPrinter::PrintOptionalInfo(const Expr& expr) { } // default annotations if (annotate_ == nullptr) { - if ((expr.as() || expr.as()) && expr->checked_type_.defined()) { - doc << " /* ty=" << Print(expr->checked_type()) << " */"; + if ((expr.as() || expr.as() || expr.as() || + expr.as() || expr.as() || expr.as()) && + (expr->checked_type_.defined() || expr->span.defined())) { + doc << " /*"; + if (expr->checked_type_.defined()) { + doc << " ty=" << Print(expr->checked_type()); + } + if (expr->span.defined()) { + doc << " span=" << PrintSpan(expr->span); + } + doc << " */"; } } else { std::string annotated_expr = annotate_(expr); @@ -219,7 +230,7 @@ Doc RelayTextPrinter::AllocVar(const Var& var) { name = "v" + name; } Doc val = GetUniqueName("%" + name); - memo_[var] = val; + memo_[var] = val; // Referential occurrences will not include the following. if (!var->virtual_device()->IsFullyUnconstrained()) { val << " {" << kVirtualDevice << "=" << PrintAttributeValue(var->virtual_device()) << "}"; } @@ -335,51 +346,17 @@ Doc RelayTextPrinter::PrintExpr(const Expr& expr, bool meta, bool try_inline, bo // first time. Doc RelayTextPrinter::VisitExpr_(const VarNode* op) { return AllocVar(GetRef(op)); } -/*! - * \brief special method to print out const scalar - * \param dtype The data type - * \param value The value to be printed. - */ -template -Doc RelayTextPrinter::ScalarLiteral(DataType dtype, const T& value) { - std::ostringstream os; - if (dtype == DataType::Int(32)) { - os << value; - } else if (dtype == DataType::Float(32)) { - os << value << 'f'; - } else if (dtype == DataType::Float(64)) { - os << value << "f64"; - } else if (dtype == DataType::Bool()) { - return Doc::PyBoolLiteral(value != 0); - } else { - os << value; - } - return Doc::Text(os.str()); -} - Doc RelayTextPrinter::VisitExpr_(const ConstantNode* op) { // Print out simple scalars directly. - if (op->is_scalar()) { - std::ostringstream os; - DataType dtype = DataType(op->data->dtype); - ICHECK_EQ(op->data->device.device_type, kDLCPU); - if (dtype == DataType::Int(32)) { - return ScalarLiteral(dtype, static_cast(op->data->data)[0]); - } else if (dtype == DataType::Int(64)) { - return ScalarLiteral(dtype, static_cast(op->data->data)[0]); - } else if (dtype == DataType::Float(32)) { - return ScalarLiteral(dtype, static_cast(op->data->data)[0]); - } else if (dtype == DataType::Float(64)) { - return ScalarLiteral(dtype, static_cast(op->data->data)[0]); - } else if (dtype == DataType::Bool()) { - return ScalarLiteral(dtype, static_cast(op->data->data)[0]); - } + if (support::IsSimpleScalar(op)) { + return Doc::Text(support::NDArrayScalarToString(op->data)); } - // default fall-back, record it as meta node. + // Fallbock: record it as a meta node. Doc doc; // Don't append optional_info. Because the entry function is Print, // and it will append the optional_info afterwards. - return doc << PrintExpr(GetRef(op), true, false, false); + return doc << PrintExpr(GetRef(op), /*meta=*/true, /*try_inline=*/false, + /*optional_info=*/false); } Doc RelayTextPrinter::VisitExpr_(const TupleNode* op) { @@ -540,9 +517,6 @@ Doc RelayTextPrinter::VisitExpr_(const CallNode* op) { return doc; } else { doc << "(" << Doc::Concat(args) << ")"; - if (op->span.defined()) { - doc << " /* " << PrintSpan(op->span) << " */"; - } return doc; } } @@ -799,11 +773,21 @@ Doc RelayTextPrinter::VisitAttr_(const ArrayNode* op) { } Doc RelayTextPrinter::VisitAttr_(const tir::IntImmNode* op) { - return ScalarLiteral(op->dtype, op->value); + if (support::IsSimpleScalarDtype(op->dtype)) { + return Doc::Text(support::IntImmToString(GetRef(op))); + } else { + // Fallback: Print int64_t without width suffix. + return Doc::Text(std::to_string(op->value)); + } } Doc RelayTextPrinter::VisitAttr_(const tir::FloatImmNode* op) { - return ScalarLiteral(op->dtype, op->value); + if (support::IsSimpleScalarDtype(op->dtype)) { + return Doc::Text(support::FloatImmToString(GetRef(op))); + } else { + // Fallbock: Print double without width suffix. + return Doc::Text(std::to_string(op->value)); + } } Doc RelayTextPrinter::VisitAttr_(const tir::StringImmNode* op) { @@ -977,7 +961,7 @@ Doc RelayTextPrinter::PrintSpan(const Span& span) { Doc doc; const auto* span_node = span.as(); ICHECK(span_node); - doc << span_node->source_name->name; + doc << span_node->source_name->name << ":" << span_node->line << ":" << span_node->column; return doc; } diff --git a/src/printer/text_printer.h b/src/printer/text_printer.h index c34c4a5b6dbe..05a00e3305e1 100644 --- a/src/printer/text_printer.h +++ b/src/printer/text_printer.h @@ -152,13 +152,6 @@ class RelayTextPrinter : public ExprFunctor, // Should only be triggered when op is a free variable being visited for the // first time. Doc VisitExpr_(const VarNode* op) final; - /*! - * \brief special method to print out const scalar - * \param dtype The data type - * \param value The value to be printed. - */ - template - static Doc ScalarLiteral(DataType dtype, const T& value); Doc VisitExpr_(const ConstantNode* op) final; Doc VisitExpr_(const TupleNode* op) final; Doc VisitExpr_(const TupleGetItemNode* op) final; diff --git a/src/support/scalars.cc b/src/support/scalars.cc new file mode 100644 index 000000000000..9caa7ca58915 --- /dev/null +++ b/src/support/scalars.cc @@ -0,0 +1,202 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/support/scalars.cc + * \brief Helpers for converting between scalars in native, text, TIR immediate and NDArray forms. + */ + +#include "./scalars.h" + +#include "tvm/relay/expr.h" +#include "tvm/runtime/builtin_fp16.h" + +namespace tvm { +namespace support { + +/*! \brief The standard scalar dtypes. */ +static const DataType kInt16 = DataType::Int(16); +static const DataType kInt32 = DataType::Int(32); +static const DataType kInt64 = DataType::Int(64); +static const DataType kFloat16 = DataType::Float(16); +static const DataType kFloat32 = DataType::Float(32); +static const DataType kFloat64 = DataType::Float(64); +static const DataType kBool = DataType::Bool(); + +bool IsSimpleScalarDtype(DataType dtype) { + return dtype == kInt16 || dtype == kInt32 || dtype == kInt64 || dtype == kFloat16 || + dtype == kFloat32 || dtype == kFloat64 || dtype == kBool; +} + +bool IsSimpleScalar(const relay::ConstantNode* constant_node) { + return constant_node->is_scalar() && IsSimpleScalarDtype(DataType(constant_node->data->dtype)); +} + +runtime::NDArray IntImmToNDArray(const IntImm& int_imm) { + DLDevice dev = {DLDeviceType::kDLCPU, 0}; + auto data = runtime::NDArray::Empty({}, int_imm->dtype, dev); + if (int_imm.dtype() == kInt16) { + auto* array = reinterpret_cast(data->data); + array[0] = static_cast(int_imm->value); + } else if (int_imm.dtype() == kInt32) { + auto* array = reinterpret_cast(data->data); + array[0] = static_cast(int_imm->value); + } else if (int_imm.dtype() == kInt64) { + auto* array = reinterpret_cast(data->data); + array[0] = int_imm->value; + } else { + LOG(FATAL) << "Unrecognized numeric literal dtype: " << DLDataType2String(int_imm.dtype()); + } + return data; +} + +runtime::NDArray FloatImmToNDArray(const FloatImm& float_imm) { + DLDevice dev = {DLDeviceType::kDLCPU, 0}; + auto data = runtime::NDArray::Empty({}, float_imm->dtype, dev); + if (float_imm.dtype() == kFloat16) { + auto* array = reinterpret_cast(data->data); + array[0] = __gnu_f2h_ieee(static_cast(float_imm->value)); + } else if (float_imm.dtype() == kFloat32) { + auto* array = reinterpret_cast(data->data); + array[0] = static_cast(float_imm->value); + } else if (float_imm.dtype() == kFloat64) { + auto* array = reinterpret_cast(data->data); + array[0] = float_imm->value; + } else { + LOG(FATAL) << "Unrecognized numeric literal dtype: " << DLDataType2String(float_imm.dtype()); + } + return data; +} + +runtime::NDArray BoolToNDArray(bool value) { + DLDevice dev = {DLDeviceType::kDLCPU, 0}; + auto data = runtime::NDArray::Empty({}, kBool, dev); + auto array = reinterpret_cast(data->data); + array[0] = value; + return data; +} + +std::string NDArrayScalarToString(const runtime::NDArray& data) { + std::ostringstream os; + DataType dtype(data->dtype); + ICHECK_EQ(data->device.device_type, kDLCPU) << "Scalars must reside on the CPU to be printed"; + if (dtype == kInt16) { + auto value = static_cast(data->data)[0]; + os << value << "i16"; + } else if (dtype == kInt32) { + auto value = static_cast(data->data)[0]; + os << value; + } else if (dtype == kInt64) { + auto value = static_cast(data->data)[0]; + os << value << "i64"; + } else if (dtype == kFloat16) { + auto value = __gnu_h2f_ieee(static_cast(data->data)[0]); + os << value << "f16"; + } else if (dtype == kFloat32) { + auto value = static_cast(data->data)[0]; + os << value << "f"; + } else if (dtype == kFloat64) { + auto value = static_cast(data->data)[0]; + os << value << "f64"; + } else if (dtype == kBool) { + auto value = static_cast(data->data)[0]; + os << (value ? "True" : "False"); + } else { + LOG(FATAL) << "Unrecognized NDArray scalar dtype: " << DLDataType2String(dtype); + } + return os.str(); +} + +std::string IntImmToString(const IntImm& int_imm) { + std::ostringstream os; + if (int_imm->dtype == kInt16) { + os << int_imm->value << "i16"; + } else if (int_imm->dtype == kInt32) { + os << int_imm->value; + } else if (int_imm->dtype == kInt64) { + os << int_imm->value << "i64"; + } else if (int_imm->dtype == kBool) { + os << (int_imm->value ? "True" : "False"); + } else { + LOG(FATAL) << "Unrecognised IntImm dtype: " << DLDataType2String(int_imm->dtype); + } + return os.str(); +} + +std::string FloatImmToString(const FloatImm& float_imm) { + std::ostringstream os; + if (float_imm->dtype == kFloat16) { + os << float_imm->value << "f16"; + } else if (float_imm->dtype == kFloat32) { + os << float_imm->value << "f"; + } else if (float_imm->dtype == kFloat64) { + os << float_imm->value << "f64"; + } else { + LOG(FATAL) << "Unrecognised FloatImm dtype: " << DLDataType2String(float_imm->dtype); + } + return os.str(); +} + +IntImm ValueToIntImm(int64_t value, int width) { + if (width == 16) { + if (value < std::numeric_limits::min() || + value > std::numeric_limits::max()) { + return {}; + } + return IntImm(kInt16, value); + } else if (width == 32) { + if (value < std::numeric_limits::min() || + value > std::numeric_limits::max()) { + return {}; + } + return IntImm(kInt32, value); + } else if (width == 64) { + return IntImm(kInt64, value); + } else { + LOG(FATAL) << "Unrecognized int scalar width: " << width; + return {}; + } +} + +// 2^15 * (1 + 1023/1024) +// See https://en.wikipedia.org/wiki/Half-precision_floating-point_format +constexpr double kMaxFloat16 = 65504.0; + +FloatImm ValueToFloatImm(double value, int width) { + if (width == 16) { + if (!std::isinf(value) && (value < -kMaxFloat16 || value > kMaxFloat16)) { + return {}; + } + return FloatImm(kFloat16, value); + } else if (width == 32) { + if (!std::isinf(value) && + (value < -std::numeric_limits::max() || value > std::numeric_limits::max())) { + return {}; + } + return FloatImm(kFloat32, value); + } else if (width == 64) { + return FloatImm(kFloat64, value); + } else { + LOG(FATAL) << "Unrecognized float scalar width: " << width; + return {}; + } +} + +} // namespace support +} // namespace tvm diff --git a/src/support/scalars.h b/src/support/scalars.h new file mode 100644 index 000000000000..60b8fc40a8de --- /dev/null +++ b/src/support/scalars.h @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/support/scalars.h + * \brief Helpers for converting between scalars in native, text, TIR immediate and NDArray forms. + */ + +#ifndef TVM_SUPPORT_SCALARS_H_ +#define TVM_SUPPORT_SCALARS_H_ + +#include +#include + +#include "tvm/ir/expr.h" +#include "tvm/relay/expr.h" +#include "tvm/runtime/ndarray.h" + +namespace tvm { +namespace support { + +/*! \brief Returns true if a tensor of empty shape and given dtype is considered a Relay scalar. */ +bool IsSimpleScalarDtype(DataType dtype); + +/*! \brief Returns true if \p constant_node is a float/int/bool scalar. */ +bool IsSimpleScalar(const relay::ConstantNode* constant_node); + +/*! \brief Returns NDArray 'scalar' for given TIR immediate. */ +runtime::NDArray IntImmToNDArray(const IntImm& int_imm); +runtime::NDArray FloatImmToNDArray(const FloatImm& float_imm); +runtime::NDArray BoolToNDArray(bool value); + +/*! \brief Returns Relay literal text for NDArray 'scalar'. */ +std::string NDArrayScalarToString(const runtime::NDArray& data); + +/*! \brief Returns Relay literal text for given TIR immediate. */ +std::string IntImmToString(const IntImm& int_imm); +std::string FloatImmToString(const FloatImm& float_imm); + +/*! + * \brief Returns TIR immediate for given value and width. Result will be null if value is + * out of range in width. Note however for floating point we don't check if the value is + * representable without loss of precision. + */ +IntImm ValueToIntImm(int64_t value, int width); +FloatImm ValueToFloatImm(double value, int width); + +} // namespace support +} // namespace tvm + +#endif // TVM_SUPPORT_SCALARS_H_ diff --git a/tests/cpp/support/scalars_test.cc b/tests/cpp/support/scalars_test.cc new file mode 100644 index 000000000000..d55f0541fa40 --- /dev/null +++ b/tests/cpp/support/scalars_test.cc @@ -0,0 +1,63 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include "../../../src/support/scalars.h" + +#include +#include + +namespace tvm { +namespace support { +namespace { + +// Note that functional testing is via test_ir_parser.py and test_ir_text_printer.py. +// Here we just check handling which is difficult to test via the standard Python API. + +TEST(Scalars, IntImmToNDArray_Unsupported) { + ASSERT_THROW(IntImmToNDArray(IntImm(DataType::Int(15), 42)), runtime::InternalError); +} + +TEST(Scalars, FloatImmtoNDArray_Unsupported) { + ASSERT_THROW(FloatImmToNDArray(FloatImm(DataType::Float(15), 42.0)), runtime::InternalError); +} + +TEST(Scalars, NDArrayScalarToString_Unsupported) { + auto ndarray = runtime::NDArray::Empty({}, DataType::Int(8), {DLDeviceType::kDLCPU, 0}); + ASSERT_THROW(NDArrayScalarToString(ndarray), runtime::InternalError); +} + +TEST(Scalars, IntImmToString_Unsupported) { + ASSERT_THROW(IntImmToString(IntImm(DataType::Int(15), 42)), runtime::InternalError); +} + +TEST(Scalars, FloatImmToString_Unsupported) { + ASSERT_THROW(FloatImmToString(FloatImm(DataType::Float(15), 42.0)), runtime::InternalError); +} + +TEST(Scalars, ValueToIntImm_Unsupported) { + ASSERT_THROW(ValueToIntImm(42, 15), runtime::InternalError); +} + +TEST(SCalars, ValueToFloatImm_Unsupported) { + ASSERT_THROW(ValueToFloatImm(42.0, 15), runtime::InternalError); +} + +} // namespace +} // namespace support +} // namespace tvm diff --git a/tests/python/relay/test_ir_parser.py b/tests/python/relay/test_ir_parser.py index fdbd3924ffb7..7a283461e0bd 100644 --- a/tests/python/relay/test_ir_parser.py +++ b/tests/python/relay/test_ir_parser.py @@ -15,11 +15,11 @@ # specific language governing permissions and limitations # under the License. import numpy as np +import pytest import tvm from tvm import relay import tvm.relay.testing -import pytest from numpy import isclose from typing import Union @@ -172,6 +172,26 @@ def test_int_literal(): assert get_scalar(parse_text("-05")) == -5 assert get_scalar(parse_text("9223372036854775807")) == 9223372036854775807 + assert get_scalar(parse_text("-42i")) == -42 + assert get_scalar(parse_text("-42i16")) == -42 + assert get_scalar(parse_text("-42i32")) == -42 + assert get_scalar(parse_text("-42i64")) == -42 + + assert_parses_as("-42i16", relay.const(-42, "int16")) + assert_parses_as("-42i32", relay.const(-42, "int32")) + assert_parses_as("-42i", relay.const(-42, "int32")) + assert_parses_as("-42", relay.const(-42, "int32")) + assert_parses_as("-42i64", relay.const(-42, "int64")) + assert_parses_as("2147483647", relay.const(2147483647, "int32")) + assert_parses_as("2147483648", relay.const(2147483648, "int64")) + + with pytest.raises(tvm.error.DiagnosticError): + # Unrepresentable + parse_text("2147483648i32") + with pytest.raises(tvm.error.DiagnosticError): + # Unrepresentable + parse_text("32768i16") + def test_float_literal(): assert get_scalar(parse_text("1.0f")) == 1.0 @@ -189,11 +209,28 @@ def test_float_literal(): assert isclose(get_scalar(parse_text("1.0E-1f")), 1.0e-1) assert get_scalar(parse_text("1.0E+1f")) == 1.0e1 + assert get_scalar(parse_text("3f16")) == 3.0 + assert get_scalar(parse_text("3f32")) == 3.0 + + assert_parses_as("3f16", relay.const(3.0, "float16")) + assert_parses_as("3f32", relay.const(3.0, "float32")) + assert_parses_as("3f", relay.const(3.0, "float32")) + assert_parses_as("3f64", relay.const(3.0, "float64")) + + with pytest.raises(tvm.error.DiagnosticError): + # Unrepresentable + parse_text("3.40283e+38f32") + with pytest.raises(tvm.error.DiagnosticError): + # Unrepresentable + parse_text("65505f16") + def test_bool_literal(): assert get_scalar(parse_text("True")) == True assert get_scalar(parse_text("False")) == False + assert_parses_as("True", relay.const(True, "bool")) + def test_negative(): # need to handle parsing non-literal operations @@ -993,4 +1030,4 @@ def @main(%x: Tensor[(2, 3), float32]) { if __name__ == "__main__": import sys - pytest.main(sys.argv) + sys.exit(pytest.main([__file__] + sys.argv[1:])) diff --git a/tests/python/relay/test_ir_text_printer.py b/tests/python/relay/test_ir_text_printer.py index 54e0e4c7ca44..60f611998649 100644 --- a/tests/python/relay/test_ir_text_printer.py +++ b/tests/python/relay/test_ir_text_printer.py @@ -47,16 +47,28 @@ def show(text): print(text) -# Commented due to weird memory allocation error -# def test_large_graph(): -# x = relay.var("x", shape=(3, 2)) -# y = relay.var("y") -# one = relay.const(10e10, dtype="float32") -# z = relay.add(x, one) -# for i in range(int(9e5)): -# z = relay.add(z, one) -# f = relay.Function([x, y], z) -# show(astext(f)) +def assert_prints_as(expr, str): + assert astext(expr) == SEMVER + str + + +def test_scalars(): + assert_prints_as(relay.const(42, "int16"), "42i16") + assert_prints_as(relay.const(42, "int32"), "42") + assert_prints_as(relay.const(42, "int64"), "42i64") + assert_prints_as(relay.const(3.0, "float16"), "3f16") + assert_prints_as(relay.const(3.0, "float32"), "3f") + assert_prints_as(relay.const(3.0, "float64"), "3f64") + + +def test_large_graph(): + x = relay.var("x", shape=(3, 2)) + y = relay.var("y") + one = relay.const(10e10, dtype="float32") + z = relay.add(x, one) + for i in range(int(9e4)): + z = relay.add(z, one) + f = relay.Function([x, y], z) + show(astext(f)) def test_func(): @@ -295,4 +307,7 @@ def test_slash_in_identifier(): if __name__ == "__main__": - pytest.main([__file__]) + import sys + import pytest + + sys.exit(pytest.main([__file__] + sys.argv[1:]))