Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Relay] Support i16, f16 scalars in Relay text #11224

Merged
merged 7 commits into from
May 19, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 6 additions & 39 deletions src/parser/parser.cc
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,12 @@

#include <fstream>

#include "../support/scalars.h"
#include "./meta_ref.h"
#include "./op_table.h"
#include "./span_check.h"
#include "./tokenizer.h"
#include "tvm/runtime/builtin_fp16.h"

namespace tvm {
namespace parser {
Expand Down Expand Up @@ -534,49 +536,15 @@ class Parser {
/*! \brief Convert a numeric token to an NDArray for embedding into the Relay program. */
NDArray NumberToNDArray(const Token& token) {
if (token->token_type == TokenType::kInteger) {
DLDevice dev = {DLDeviceType::kDLCPU, 0};
int64_t i = Downcast<tvm::Integer>(token->data);
if (i > std::numeric_limits<int32_t>::max()) {
auto dtype = String2DLDataType("int64");
auto data = NDArray::Empty({}, dtype, dev);
auto array = reinterpret_cast<int64_t*>(data->data);
// revisit this, literal node issue.
array[0] = i;
return data;
} else {
auto dtype = String2DLDataType("int32");
auto data = NDArray::Empty({}, dtype, dev);
auto array = reinterpret_cast<int32_t*>(data->data);
// revisit this, literal node issue.
array[0] = i;
return data;
}
return support::IntImmToNDArray(Downcast<tvm::IntImm>(token->data));
} else if (token->token_type == TokenType::kFloat) {
DLDevice dev = {DLDeviceType::kDLCPU, 0};
auto float_imm = Downcast<tvm::FloatImm>(token->data);
auto data = NDArray::Empty({}, float_imm->dtype, dev);
auto array = reinterpret_cast<float*>(data->data);
// revisit this, literal node issue.
// TODO(@jroesch): bounds checking
float value = float_imm->value;
array[0] = value;
return data;
return support::FloatImmToNDArray(Downcast<tvm::FloatImm>(token->data));
} else {
LOG(FATAL) << "internal error: should only call this function on numeric tokens";
return NDArray();
return {};
}
}

/*! \brief Convert a boolean value to an NDArray for embedding into the Relay program. */
NDArray BooleanToNDarray(bool value) {
DLDevice dev = {DLDeviceType::kDLCPU, 0};
auto dtype = String2DLDataType("bool");
auto data = NDArray::Empty({}, dtype, dev);
auto array = reinterpret_cast<bool*>(data->data);
array[0] = value;
return data;
}

[[noreturn]] void ParseError(const Token& token, const std::string& msg) {
throw std::runtime_error(msg);
}
Expand Down Expand Up @@ -1573,8 +1541,7 @@ class Parser {
case TokenType::kBoolean: {
Consume(TokenType::kBoolean);
int64_t value = Downcast<tvm::Integer>(next->data);
auto boolean = BooleanToNDarray(value);
Expr e = Constant(boolean, next->span);
Expr e = Constant(support::BoolToNDArray(value), next->span);
ICHECK(e->span.defined()) << "constant spans must be defined";
return e;
}
Expand Down
104 changes: 69 additions & 35 deletions src/parser/tokenizer.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
#include <utility>
#include <vector>

#include "../support/scalars.h"
#include "./meta_ref.h"
#include "./token.h"

Expand Down Expand Up @@ -174,35 +175,16 @@ struct Tokenizer {
Token ParseNumber(bool is_pos, bool is_float, std::string number) {
ICHECK(number.size() > 0) << "an empty string is an invalid number";

if (!is_float) {
auto token = NewToken(TokenType::kInteger);
size_t index = 0;
int64_t value = 0;
try {
value = std::stoll(number, &index);
} catch (const std::invalid_argument& err) {
this->diag_ctx.Emit(Diagnostic::Error(token->span) << "invalid number `" << number << "`");
} catch (const std::out_of_range& err) {
this->diag_ctx.Emit(Diagnostic::Error(token->span) << "invalid number `" << number << "`");
}
if (number.size() <= index) {
value = is_pos ? value : -value;
if (value > std::numeric_limits<int32_t>::max()) {
token->data = tvm::IntImm(DataType::Int(64), value);
} else {
token->data = tvm::IntImm(DataType::Int(32), value);
}
return token;
}
Token token = NewToken(is_float ? TokenType::kFloat : TokenType::kInteger);
size_t suffix_pos = number.rfind(is_float ? 'f' : 'i');
if (suffix_pos == std::string::npos) {
suffix_pos = number.size();
}
std::string literal_text = number.substr(0, suffix_pos);
std::string suffix;
if (suffix_pos < number.size()) {
suffix = number.substr(suffix_pos + 1, number.size() - suffix_pos);
}
auto token = NewToken(TokenType::kFloat);

auto suffix_pos = number.rfind("f");

auto literal_text = number.substr(0, suffix_pos);

auto suffix = number.substr(suffix_pos + 1, number.size() - suffix_pos);

int width = 32;

if (suffix.size()) {
Expand All @@ -217,9 +199,62 @@ struct Tokenizer {
}
}

double value = stod(literal_text);
value = is_pos ? value : -value;
token->data = tvm::FloatImm(DataType::Float(width), value);
if (is_float) {
double value = 0.0;
size_t index = 0;
try {
value = stod(literal_text, &index);
} catch (const std::invalid_argument& err) {
this->diag_ctx.Emit(Diagnostic::Error(token->span)
<< "invalid floating point number `" << literal_text << "`");
} catch (const std::out_of_range& err) {
this->diag_ctx.Emit(Diagnostic::Error(token->span)
<< "invalid floating point number `" << literal_text << "`");
}
if (index < literal_text.size()) {
this->diag_ctx.Emit(Diagnostic::Error(token->span)
<< "invalid floating point number `" << literal_text << "`");
}
value = is_pos ? value : -value;
token->data = support::ValueToFloatImm(value, width);
if (!token->data.defined()) {
this->diag_ctx.Emit(Diagnostic::Error(token->span)
<< "floating point number `" << literal_text
<< "` unrepresentable in width " << width);
token->data = support::ValueToFloatImm(0.0, width);
}
} else {
int64_t value = 0;
size_t index = 0;
try {
value = std::stoll(literal_text, &index);
} catch (const std::invalid_argument& err) {
this->diag_ctx.Emit(Diagnostic::Error(token->span)
<< "invalid integer number `" << literal_text << "`");
} catch (const std::out_of_range& err) {
this->diag_ctx.Emit(Diagnostic::Error(token->span)
<< "invalid integer number `" << literal_text << "`");
}
if (index < literal_text.size()) {
this->diag_ctx.Emit(Diagnostic::Error(token->span)
<< "invalid integer number `" << literal_text << "`");
}
value = is_pos ? value : -value;
token->data = support::ValueToIntImm(value, width);
if (!token->data.defined() && suffix.empty()) {
// Without any i suffix the legacy behavior was to default to int64 if out of range
// for int32.
width = 64;
token->data = support::ValueToIntImm(value, width);
}
if (!token->data.defined()) {
this->diag_ctx.Emit(Diagnostic::Error(token->span)
<< "integer number `" << literal_text << "` unrepresentable in width "
<< width);
token->data = support::ValueToIntImm(0, width);
}
}

return token;
}

Expand All @@ -230,14 +265,13 @@ struct Tokenizer {
}

bool is_float = false;

// Remove trailing floating point prefix.
if (More() && Peek() == 'f') {
if (More() && (Peek() == 'f' || Peek() == 'i')) {
is_float = Peek() == 'f';
// Capture trailing width suffix
ss << Next();
while (More() && IsNumeric(Peek())) {
ss << Next();
}
is_float = true;
}
return ParseNumber(is_pos, is_float, ss.str());
}
Expand Down
7 changes: 1 addition & 6 deletions src/printer/doc.cc
Original file line number Diff line number Diff line change
Expand Up @@ -52,12 +52,7 @@ TVM_REGISTER_OBJECT_TYPE(DocTextNode);

class DocText : public DocAtom {
public:
explicit DocText(std::string str) {
if (str.find_first_of("\t\n") != str.npos) {
LOG(WARNING) << "text node: '" << str << "' should not have tab or newline.";
}
data_ = runtime::make_object<DocTextNode>(str);
}
explicit DocText(std::string str) { data_ = runtime::make_object<DocTextNode>(str); }

TVM_DEFINE_OBJECT_REF_METHODS(DocText, DocAtom, DocTextNode);
};
Expand Down
80 changes: 32 additions & 48 deletions src/printer/relay_text_printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,11 @@
#include "../ir/attr_functor.h"
#include "../parser/meta_ref.h"
#include "../relay/analysis/dependency_graph.h"
#include "../support/scalars.h"
#include "doc.h"
#include "meta_data.h"
#include "text_printer.h"
#include "tvm/runtime/builtin_fp16.h"

namespace tvm {
namespace relay {
Expand All @@ -61,8 +63,17 @@ Doc RelayTextPrinter::PrintOptionalInfo(const Expr& expr) {
}
// default annotations
if (annotate_ == nullptr) {
if ((expr.as<ConstantNode>() || expr.as<CallNode>()) && expr->checked_type_.defined()) {
doc << " /* ty=" << Print(expr->checked_type()) << " */";
if ((expr.as<ConstantNode>() || expr.as<CallNode>() || expr.as<VarNode>() ||
expr.as<FunctionNode>() || expr.as<TupleNode>() || expr.as<TupleGetItemNode>()) &&
(expr->checked_type_.defined() || expr->span.defined())) {
doc << " /*";
if (expr->checked_type_.defined()) {
doc << " ty=" << Print(expr->checked_type());
}
if (expr->span.defined()) {
doc << " span=" << PrintSpan(expr->span);
}
doc << " */";
}
} else {
std::string annotated_expr = annotate_(expr);
Expand Down Expand Up @@ -219,7 +230,7 @@ Doc RelayTextPrinter::AllocVar(const Var& var) {
name = "v" + name;
}
Doc val = GetUniqueName("%" + name);
memo_[var] = val;
memo_[var] = val; // Referential occurrences will not include the following.
if (!var->virtual_device()->IsFullyUnconstrained()) {
val << " {" << kVirtualDevice << "=" << PrintAttributeValue(var->virtual_device()) << "}";
}
Expand Down Expand Up @@ -335,51 +346,17 @@ Doc RelayTextPrinter::PrintExpr(const Expr& expr, bool meta, bool try_inline, bo
// first time.
Doc RelayTextPrinter::VisitExpr_(const VarNode* op) { return AllocVar(GetRef<Var>(op)); }

/*!
* \brief special method to print out const scalar
* \param dtype The data type
* \param value The value to be printed.
*/
template <typename T>
Doc RelayTextPrinter::ScalarLiteral(DataType dtype, const T& value) {
std::ostringstream os;
if (dtype == DataType::Int(32)) {
os << value;
} else if (dtype == DataType::Float(32)) {
os << value << 'f';
} else if (dtype == DataType::Float(64)) {
os << value << "f64";
} else if (dtype == DataType::Bool()) {
return Doc::PyBoolLiteral(value != 0);
} else {
os << value;
}
return Doc::Text(os.str());
}

Doc RelayTextPrinter::VisitExpr_(const ConstantNode* op) {
// Print out simple scalars directly.
if (op->is_scalar()) {
std::ostringstream os;
DataType dtype = DataType(op->data->dtype);
ICHECK_EQ(op->data->device.device_type, kDLCPU);
if (dtype == DataType::Int(32)) {
return ScalarLiteral(dtype, static_cast<const int32_t*>(op->data->data)[0]);
} else if (dtype == DataType::Int(64)) {
return ScalarLiteral(dtype, static_cast<const int64_t*>(op->data->data)[0]);
} else if (dtype == DataType::Float(32)) {
return ScalarLiteral(dtype, static_cast<const float*>(op->data->data)[0]);
} else if (dtype == DataType::Float(64)) {
return ScalarLiteral(dtype, static_cast<const double*>(op->data->data)[0]);
} else if (dtype == DataType::Bool()) {
return ScalarLiteral(dtype, static_cast<const uint8_t*>(op->data->data)[0]);
}
if (support::IsSimpleScalar(op)) {
return Doc::Text(support::NDArrayScalarToString(op->data));
}
// default fall-back, record it as meta node.
// Fallbock: record it as a meta node.
Doc doc;
// Don't append optional_info. Because the entry function is Print,
// and it will append the optional_info afterwards.
return doc << PrintExpr(GetRef<Expr>(op), true, false, false);
return doc << PrintExpr(GetRef<Expr>(op), /*meta=*/true, /*try_inline=*/false,
/*optional_info=*/false);
}

Doc RelayTextPrinter::VisitExpr_(const TupleNode* op) {
Expand Down Expand Up @@ -540,9 +517,6 @@ Doc RelayTextPrinter::VisitExpr_(const CallNode* op) {
return doc;
} else {
doc << "(" << Doc::Concat(args) << ")";
if (op->span.defined()) {
doc << " /* " << PrintSpan(op->span) << " */";
}
return doc;
}
}
Expand Down Expand Up @@ -799,11 +773,21 @@ Doc RelayTextPrinter::VisitAttr_(const ArrayNode* op) {
}

Doc RelayTextPrinter::VisitAttr_(const tir::IntImmNode* op) {
return ScalarLiteral(op->dtype, op->value);
if (support::IsSimpleScalarDtype(op->dtype)) {
return Doc::Text(support::IntImmToString(GetRef<IntImm>(op)));
} else {
// Fallback: Print int64_t without width suffix.
return Doc::Text(std::to_string(op->value));
}
}

Doc RelayTextPrinter::VisitAttr_(const tir::FloatImmNode* op) {
return ScalarLiteral(op->dtype, op->value);
if (support::IsSimpleScalarDtype(op->dtype)) {
return Doc::Text(support::FloatImmToString(GetRef<FloatImm>(op)));
} else {
// Fallbock: Print double without width suffix.
return Doc::Text(std::to_string(op->value));
}
}

Doc RelayTextPrinter::VisitAttr_(const tir::StringImmNode* op) {
Expand Down Expand Up @@ -977,7 +961,7 @@ Doc RelayTextPrinter::PrintSpan(const Span& span) {
Doc doc;
const auto* span_node = span.as<SpanNode>();
ICHECK(span_node);
doc << span_node->source_name->name;
doc << span_node->source_name->name << ":" << span_node->line << ":" << span_node->column;
return doc;
}

Expand Down
7 changes: 0 additions & 7 deletions src/printer/text_printer.h
Original file line number Diff line number Diff line change
Expand Up @@ -152,13 +152,6 @@ class RelayTextPrinter : public ExprFunctor<Doc(const Expr&)>,
// Should only be triggered when op is a free variable being visited for the
// first time.
Doc VisitExpr_(const VarNode* op) final;
/*!
* \brief special method to print out const scalar
* \param dtype The data type
* \param value The value to be printed.
*/
template <typename T>
static Doc ScalarLiteral(DataType dtype, const T& value);
Doc VisitExpr_(const ConstantNode* op) final;
Doc VisitExpr_(const TupleNode* op) final;
Doc VisitExpr_(const TupleGetItemNode* op) final;
Expand Down
Loading