Skip to content

Commit

Permalink
Fix CreatePrimFunc for link-params=True case
Browse files Browse the repository at this point in the history
  • Loading branch information
masahi committed Nov 15, 2022
1 parent 29aa4ee commit e91f675
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 2 deletions.
2 changes: 1 addition & 1 deletion src/meta_schedule/postproc/rewrite_layout.cc
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,7 @@ bool RewriteLayout(const Schedule& sch) {
BlockRV cache_read_block_rv = sch->GetBlock(cache_read_chain[i], func_name);
if (i == 0) {
// Before the first cache_read that consumes the layout-free buffer, insert
// a layout-rewrite block. Another cache read buffer is added, and its layout is
// a layout-rewrite block. Another cache-read buffer is added, and its layout is
// transformed by TransformLayout below.
add_layout_rewrite_block(cache_read_block_rv, 0);
}
Expand Down
9 changes: 8 additions & 1 deletion src/te/operation/create_primfunc.cc
Original file line number Diff line number Diff line change
Expand Up @@ -110,13 +110,20 @@ class LayoutFreePlaceholdersNormalizer : public StmtMutator {
Block block = Downcast<Block>(StmtMutator::VisitStmt_(_block));
BlockNode* n = block.CopyOnWrite();
if (Optional<ObjectRef> ann = n->annotations.Get(topi_attr)) {
Array<Buffer> new_buffers;
for (Buffer buffer : Downcast<Array<Buffer>>(ann)) {
auto it = buffer2index_.find(buffer);
if (it != buffer2index_.end()) {
layout_free_buffer_indices_.insert(it->second);
} else {
new_buffers.push_back(buffer);
}
}
n->annotations.erase(topi_attr);
if (new_buffers.empty()) {
n->annotations.erase(topi_attr);
} else {
n->annotations.Set(topi_attr, new_buffers);
}
}
for (const String& attr : this->blocklist) {
auto it = n->annotations.find(attr);
Expand Down

0 comments on commit e91f675

Please sign in to comment.