Skip to content

Commit

Permalink
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Support printing StmtDoc in PythonDocPrinter
Browse files Browse the repository at this point in the history
yelite committed Jul 25, 2022

Verified

This commit was signed with the committer’s verified signature.
1 parent e42741d commit 763ff53
Showing 4 changed files with 906 additions and 6 deletions.
22 changes: 22 additions & 0 deletions src/script/printer/base_doc_printer.cc
Original file line number Diff line number Diff line change
@@ -58,6 +58,28 @@ void DocPrinter::PrintDoc(const Doc& doc) {
PrintTypedDoc(GetRef<DictDoc>(doc_node));
} else if (const auto* doc_node = doc.as<SliceDocNode>()) {
PrintTypedDoc(GetRef<SliceDoc>(doc_node));
} else if (const auto* doc_node = doc.as<StmtBlockDocNode>()) {
PrintTypedDoc(GetRef<StmtBlockDoc>(doc_node));
} else if (const auto* doc_node = doc.as<AssignDocNode>()) {
PrintTypedDoc(GetRef<AssignDoc>(doc_node));
} else if (const auto* doc_node = doc.as<IfDocNode>()) {
PrintTypedDoc(GetRef<IfDoc>(doc_node));
} else if (const auto* doc_node = doc.as<WhileDocNode>()) {
PrintTypedDoc(GetRef<WhileDoc>(doc_node));
} else if (const auto* doc_node = doc.as<ForDocNode>()) {
PrintTypedDoc(GetRef<ForDoc>(doc_node));
} else if (const auto* doc_node = doc.as<ScopeDocNode>()) {
PrintTypedDoc(GetRef<ScopeDoc>(doc_node));
} else if (const auto* doc_node = doc.as<ExprStmtDocNode>()) {
PrintTypedDoc(GetRef<ExprStmtDoc>(doc_node));
} else if (const auto* doc_node = doc.as<AssertDocNode>()) {
PrintTypedDoc(GetRef<AssertDoc>(doc_node));
} else if (const auto* doc_node = doc.as<ReturnDocNode>()) {
PrintTypedDoc(GetRef<ReturnDoc>(doc_node));
} else if (const auto* doc_node = doc.as<FunctionDocNode>()) {
PrintTypedDoc(GetRef<FunctionDoc>(doc_node));
} else if (const auto* doc_node = doc.as<ClassDocNode>()) {
PrintTypedDoc(GetRef<ClassDoc>(doc_node));
} else {
LOG(FATAL) << "Do not know how to print " << doc->GetTypeKey();
throw;
63 changes: 59 additions & 4 deletions src/script/printer/base_doc_printer.h
Original file line number Diff line number Diff line change
@@ -84,22 +84,22 @@ class DocPrinter {
virtual void PrintTypedDoc(const LiteralDoc& doc) = 0;

/*!
* \brief Virtual method to print a IdDoc
* \brief Virtual method to print an IdDoc
*/
virtual void PrintTypedDoc(const IdDoc& doc) = 0;

/*!
* \brief Virtual method to print a AttrAccessDoc
* \brief Virtual method to print an AttrAccessDoc
*/
virtual void PrintTypedDoc(const AttrAccessDoc& doc) = 0;

/*!
* \brief Virtual method to print a IndexDoc
* \brief Virtual method to print an IndexDoc
*/
virtual void PrintTypedDoc(const IndexDoc& doc) = 0;

/*!
* \brief Virtual method to print a OperationDoc
* \brief Virtual method to print an OperationDoc
*/
virtual void PrintTypedDoc(const OperationDoc& doc) = 0;

@@ -133,6 +133,61 @@ class DocPrinter {
*/
virtual void PrintTypedDoc(const SliceDoc& doc) = 0;

/*!
* \brief Virtual method to print a StmtBlockDoc
*/
virtual void PrintTypedDoc(const StmtBlockDoc& doc) = 0;

/*!
* \brief Virtual method to print an AssignDoc
*/
virtual void PrintTypedDoc(const AssignDoc& doc) = 0;

/*!
* \brief Virtual method to print an IfDoc
*/
virtual void PrintTypedDoc(const IfDoc& doc) = 0;

/*!
* \brief Virtual method to print a WhileDoc
*/
virtual void PrintTypedDoc(const WhileDoc& doc) = 0;

/*!
* \brief Virtual method to print a ForDoc
*/
virtual void PrintTypedDoc(const ForDoc& doc) = 0;

/*!
* \brief Virtual method to print a ScopeDoc
*/
virtual void PrintTypedDoc(const ScopeDoc& doc) = 0;

/*!
* \brief Virtual method to print an ExprStmtDoc
*/
virtual void PrintTypedDoc(const ExprStmtDoc& doc) = 0;

/*!
* \brief Virtual method to print an AssertDoc
*/
virtual void PrintTypedDoc(const AssertDoc& doc) = 0;

/*!
* \brief Virtual method to print a ReturnDoc
*/
virtual void PrintTypedDoc(const ReturnDoc& doc) = 0;

/*!
* \brief Virtual method to print a FunctionDoc
*/
virtual void PrintTypedDoc(const FunctionDoc& doc) = 0;

/*!
* \brief Virtual method to print a ClassDoc
*/
virtual void PrintTypedDoc(const ClassDoc& doc) = 0;

/*!
* \brief Increase the indent level of any content to be
* printed after this call
212 changes: 211 additions & 1 deletion src/script/printer/python_doc_printer.cc
Original file line number Diff line number Diff line change
@@ -16,11 +16,15 @@
* specific language governing permissions and limitations
* under the License.
*/

#include <tvm/runtime/logging.h>
#include <tvm/runtime/registry.h>
#include <tvm/script/printer/doc.h>

#include <algorithm>
#include <string>

#include "../../support/str_escape.h"
#include "../../support/utils.h"
#include "./base_doc_printer.h"

namespace tvm {
@@ -45,8 +49,21 @@ class PythonDocPrinter : public DocPrinter {
void PrintTypedDoc(const DictDoc& doc) final;
void PrintTypedDoc(const TupleDoc& doc) final;
void PrintTypedDoc(const SliceDoc& doc) final;
void PrintTypedDoc(const StmtBlockDoc& doc) final;
void PrintTypedDoc(const AssignDoc& doc) final;
void PrintTypedDoc(const IfDoc& doc) final;
void PrintTypedDoc(const WhileDoc& doc) final;
void PrintTypedDoc(const ForDoc& doc) final;
void PrintTypedDoc(const ExprStmtDoc& doc) final;
void PrintTypedDoc(const AssertDoc& doc) final;
void PrintTypedDoc(const ReturnDoc& doc) final;
void PrintTypedDoc(const ScopeDoc& doc) final;
void PrintTypedDoc(const FunctionDoc& doc) final;
void PrintTypedDoc(const ClassDoc& doc) final;

private:
void NewLineWithoutIndent() { output_ << "\n"; }

template <typename DocType>
void PrintJoinedDocs(const Array<DocType>& docs, const std::string& separator) {
bool is_first = true;
@@ -59,6 +76,65 @@ class PythonDocPrinter : public DocPrinter {
PrintDoc(doc);
}
}

void PrintIndentedBlock(const Array<StmtDoc>& docs) {
IncreaseIndent();
for (const StmtDoc& d : docs) {
NewLine();
PrintDoc(d);
}
if (docs.empty()) {
NewLine();
output_ << "pass";
}
DecreaseIndent();
}

void PrintDecorators(const Array<ExprDoc>& decorators) {
for (const ExprDoc& decorator : decorators) {
output_ << "@";
PrintDoc(decorator);
NewLine();
}
}

void MaybePrintCommentInline(const StmtDoc& stmt) {
if (stmt->comment.defined()) {
const std::string& comment = stmt->comment.value();
bool has_newline = std::find(comment.begin(), comment.end(), '\n') != comment.end();
CHECK(!has_newline) << "ValueError: the comment string of " << stmt->GetTypeKey()
<< " cannot have newline.";
output_ << " # " << comment;
}
}

void MaybePrintCommentWithNewLine(const StmtDoc& stmt) {
if (stmt->comment.defined()) {
std::vector<std::string> comment_lines = support::Split(stmt->comment.value(), '\n');
for (const std::string& line : comment_lines) {
output_ << "# " << line;
NewLine();
}
}
}

void PrintBlockComment(const String& comment) {
IncreaseIndent();
NewLine() << "\"\"\"";

std::vector<std::string> comment_lines = support::Split(comment, '\n');
for (const std::string& line : comment_lines) {
if (line.empty()) {
// No indentation on empty line
output_ << "\n";
} else {
NewLine() << line;
}
}

NewLine() << "\"\"\"";
DecreaseIndent();
}
};

void PythonDocPrinter::PrintTypedDoc(const LiteralDoc& doc) {
@@ -260,6 +336,140 @@ void PythonDocPrinter::PrintTypedDoc(const SliceDoc& doc) {
}
}

void PythonDocPrinter::PrintTypedDoc(const StmtBlockDoc& doc) {
for (const StmtDoc& stmt : doc->stmts) {
PrintDoc(stmt);
NewLine();
}
}

void PythonDocPrinter::PrintTypedDoc(const AssignDoc& doc) {
if (const auto* tuple_doc = doc->lhs.as<TupleDocNode>()) {
PrintJoinedDocs(tuple_doc->elements, ", ");
} else {
PrintDoc(doc->lhs);
}

if (doc->annotation) {
output_ << ": ";
PrintDoc(doc->annotation.value());
}
if (doc->rhs) {
output_ << " = ";
PrintDoc(doc->rhs.value());
}
MaybePrintCommentInline(doc);
}

void PythonDocPrinter::PrintTypedDoc(const IfDoc& doc) {
MaybePrintCommentWithNewLine(doc);
output_ << "if ";
PrintDoc(doc->predicate);
output_ << ":";

PrintIndentedBlock(doc->then_branch);

if (!doc->else_branch.empty()) {
NewLine();
output_ << "else:";
PrintIndentedBlock(doc->else_branch);
}
}

void PythonDocPrinter::PrintTypedDoc(const WhileDoc& doc) {
MaybePrintCommentWithNewLine(doc);
output_ << "while ";
PrintDoc(doc->predicate);
output_ << ":";

PrintIndentedBlock(doc->body);
}

void PythonDocPrinter::PrintTypedDoc(const ForDoc& doc) {
MaybePrintCommentWithNewLine(doc);
output_ << "for ";
PrintDoc(doc->lhs);
output_ << " in ";
PrintDoc(doc->rhs);
output_ << ":";

PrintIndentedBlock(doc->body);
}

void PythonDocPrinter::PrintTypedDoc(const ScopeDoc& doc) {
MaybePrintCommentWithNewLine(doc);
output_ << "with ";
PrintDoc(doc->rhs);
if (doc->lhs != nullptr) {
output_ << " as ";
PrintDoc(doc->lhs.value());
}
output_ << ":";

PrintIndentedBlock(doc->body);
}

void PythonDocPrinter::PrintTypedDoc(const ExprStmtDoc& doc) {
PrintDoc(doc->expr);
MaybePrintCommentInline(doc);
}

void PythonDocPrinter::PrintTypedDoc(const AssertDoc& doc) {
output_ << "assert ";
PrintDoc(doc->test);
if (doc->msg.defined()) {
output_ << ", ";
PrintDoc(doc->msg.value());
}
MaybePrintCommentInline(doc);
}

void PythonDocPrinter::PrintTypedDoc(const ReturnDoc& doc) {
output_ << "return ";
PrintDoc(doc->value);
MaybePrintCommentInline(doc);
}

void PythonDocPrinter::PrintTypedDoc(const FunctionDoc& doc) {
for (const AssignDoc& arg_doc : doc->args) {
ICHECK(arg_doc->comment == nullptr) << "Function arg cannot have comment attached to them.";
}

PrintDecorators(doc->decorators);

output_ << "def ";
PrintDoc(doc->name);

output_ << "(";
PrintJoinedDocs(doc->args, ", ");
output_ << ")";

output_ << " -> ";
PrintDoc(doc->return_type);

output_ << ":";

if (doc->comment.defined()) {
PrintBlockComment(doc->comment.value());
}
PrintIndentedBlock(doc->body);
NewLineWithoutIndent();
}

void PythonDocPrinter::PrintTypedDoc(const ClassDoc& doc) {
PrintDecorators(doc->decorators);

output_ << "class ";
PrintDoc(doc->name);
output_ << ":";

if (doc->comment.defined()) {
PrintBlockComment(doc->comment.value());
}
PrintIndentedBlock(doc->body);
NewLineWithoutIndent();
}

String DocToPythonScript(Doc doc, int indent_spaces) {
PythonDocPrinter printer(indent_spaces);
printer.Append(doc);
Loading

0 comments on commit 763ff53

Please sign in to comment.