Skip to content

Commit

Permalink
Add more docstrings and depress warnings for new lowering algorithm. (a…
Browse files Browse the repository at this point in the history
  • Loading branch information
yzh119 authored Dec 13, 2021
1 parent db6b6ab commit 614fb8a
Showing 1 changed file with 86 additions and 38 deletions.
124 changes: 86 additions & 38 deletions src/tir/transforms/lower_sparse_tir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -85,10 +85,11 @@ Map<Var, Buffer> UpdateBufferMap(PrimFunc f) {
}

/*!
* \brief Compupte the partially lowered index.
* \brief Aggregate offset on previous axes with the index on current axis.
* \param prev_offset The lowered index accumulated over all axis prior to current axis.
* \param axis Current axis.
* \param index The sparse index on current axis.
* \param ana_ The analyzer used for simplifying expressions. TODO(zihao): make it more cleaner.
* \return The lowered index.
*/
PrimExpr AggregateOffset(PrimExpr prev_offset, const Axis& axis, PrimExpr index,
Expand Down Expand Up @@ -167,8 +168,14 @@ class SparseBlockCtx {
offset_[nullptr] = Integer(0);
}

Optional<SpIterVar> GetSparseIterVar(const VarNode* var_node) const {
auto it = sp_iter_var_map_.find(var_node);
/*!
* \brief Get sparse iter var corresponding to given variable node in the current scope.
* \param var The variable node in AST.
* \return A optional wrapper of sparse iter var. If var is not a sparse iter var, return
* NullOpt.
*/
Optional<SpIterVar> GetSparseIterVar(const VarNode* var) const {
auto it = sp_iter_var_map_.find(var);
if (it != sp_iter_var_map_.end()) {
return it->second;
} else {
Expand All @@ -177,8 +184,9 @@ class SparseBlockCtx {
}

/*!
* \brief Get coordinate of corresding sparse iter var.
* \brief Get coordinate of corresding sparse iter var in the current scope.
* \param sp_iter_var The compressed iterator.
* \return A PrimExpr representing the coordinate.
*/
PrimExpr GetCoordinate(const SpIterVarNode* sp_iter_var) {
const Axis& axis = sp_iter_var->axis;
Expand All @@ -196,7 +204,10 @@ class SparseBlockCtx {
}
}

/*! \brief TODO
/*!
* \brief Get the real offset in compressed buffer of given sparse iter var.
* \param sp_iter_var The sparse iter var to lookup.
* \return A PrimExpr representing the offset.
*/
PrimExpr GetOffset(const SpIterVarNode* sp_iter_var) {
auto it = offset_.find(sp_iter_var);
Expand All @@ -210,7 +221,11 @@ class SparseBlockCtx {
}
}

/*! \brief TODO
/*!
* \brief Get the indices range in compressed buffer of given sparse iter var.
* \param sp_iter_var The sparse iter var to lookup.
* \return A tuple of PrimExpr, the first elements refers to the start position, and the second
* elements refers the end position.
*/
std::tuple<PrimExpr, PrimExpr> GetIndicesRange(const SpIterVarNode* sp_iter_var) {
PrimExpr prev_off = GetOffset(parent_[sp_iter_var]);
Expand All @@ -219,7 +234,8 @@ class SparseBlockCtx {
AggregateOffset(add(prev_off, 1), axis, Integer(0), &ana_)};
}

/*! \brief TODO
/*!
* \brief Get the current block name.
*/
const String GetBlockName() const { return blk_name_; }

Expand All @@ -231,57 +247,66 @@ class SparseBlockCtx {
String blk_name_;
};

/*! \brief default constructor */
explicit SparseBlockCtx(AxisTree tree) : tree_(std::move(tree)) {}

/*! \brief enter new scope */
void EnterScope(const SparseBlockNode* sp_block) {
stack_.emplace_back(sp_block->name, sp_block->sp_iter_vars, tree_);
}

/*! \brief exit current scope */
void ExitScope() { stack_.pop_back(); }

Optional<SpIterVar> GetSparseIterVar(const VarNode* var_node) const {
return local()->GetSparseIterVar(var_node);
/*! \brief call GetSparseIterVar in the top scope. */
Optional<SpIterVar> GetSparseIterVar(const VarNode* node) const {
return top()->GetSparseIterVar(node);
}

PrimExpr GetCoordinate(const SpIterVarNode* node) { return local()->GetCoordinate(node); }
/*! \brief call GetCoordinate in the top scope. */
PrimExpr GetCoordinate(const SpIterVarNode* node) { return top()->GetCoordinate(node); }

/*! \brief call GetIndicesRange in the top scope. */
std::tuple<PrimExpr, PrimExpr> GetIndicesRange(const SpIterVarNode* sp_iter_var) {
return local()->GetIndicesRange(sp_iter_var);
return top()->GetIndicesRange(sp_iter_var);
}

const String GetBlockName() const { return local()->GetBlockName(); }
/*! \brief call GetBlockName in the top scope. */
const String GetBlockName() const { return top()->GetBlockName(); }

private:
std::vector<Scope> stack_;
AxisTree tree_;

inline Scope* local() const { return const_cast<Scope*>(&stack_.back()); }
/*! \brief the top scope in the sparse block stack. */
inline Scope* top() const { return const_cast<Scope*>(&stack_.back()); }
};

/*! \brief Storing the context information of a sparse buffer. */
class SparseBufferCtx {
public:
class Scope {
public:
// move constructor
/*! \brief move constructor */
explicit Scope(Scope&& other)
: buf_name_(std::move(other.buf_name_)),
axes_(std::move(other.axes_)),
offsets_(std::move(other.offsets_)),
matches_(std::move(other.matches_)),
sp_blk_ctx_(std::move(other.sp_blk_ctx_)) {}

// default constructor
/*! \brief default constructor */
explicit Scope(String buf_name, Array<Axis> axes, const SparseBlockCtx* sp_blk_ctx)
: buf_name_(std::move(buf_name)), axes_(std::move(axes)), sp_blk_ctx_(sp_blk_ctx) {
offsets_.emplace_back(Integer(0));
matches_.emplace_back(true);
}

void Register(int idx, PrimExpr coordinate, PrimExpr orig_idx) {
ICHECK(idx + 1 == int(offsets_.size()))
<< "Cannot register coordinate of index " << std::to_string(idx) << " at this time";
const Axis& axis = GetAxis(idx);
/*! \brief register the coordinate of a new dimension of current buffer. */
void Register(int dim, PrimExpr coordinate, PrimExpr orig_idx) {
ICHECK(dim + 1 == int(offsets_.size()))
<< "Cannot register coordinate of index " << std::to_string(dim) << " at this time";
const Axis& axis = GetAxis(dim);

// update matches boolean array
if (!matches_.back()) {
Expand All @@ -305,17 +330,20 @@ class SparseBufferCtx {
offsets_.emplace_back(std::move(new_offset));
}

const Axis& GetAxis(int idx) const {
auto && ret = axes_[idx];
/*! \brief get the axis given dimension index of current buffer. */
const Axis& GetAxis(int dim) const {
auto&& ret = axes_[dim];
return ret;
}

/*! \brief whether the index access pattern of current buffer aligns with current block */
const inline bool MatchWithSpBlock() const { return matches_.back(); }

std::tuple<PrimExpr, PrimExpr> GetIndicesRange(int idx) {
const Axis& axis = axes_[idx];
return {AggregateOffset(offsets_[idx], axis, Integer(0), &ana_),
AggregateOffset(add(offsets_[idx], 1), axis, Integer(0), &ana_)};
/*! \brief return the indices range of the given dimension in current buffer. */
std::tuple<PrimExpr, PrimExpr> GetIndicesRange(int dim) {
const Axis& axis = axes_[dim];
return {AggregateOffset(offsets_[dim], axis, Integer(0), &ana_),
AggregateOffset(add(offsets_[dim], 1), axis, Integer(0), &ana_)};
}

private:
Expand All @@ -327,31 +355,41 @@ class SparseBufferCtx {
const SparseBlockCtx* sp_blk_ctx_;
};

/*! \brief default constructor */
explicit SparseBufferCtx(AxisTree tree) : tree_(std::move(tree)) {}

/*! \brief enter new scope */
void EnterScope(const SparseBuffer& sp_buf, const SparseBlockCtx* sp_blk_ctx) {
stack_.emplace_back(sp_buf->name, sp_buf->axes, sp_blk_ctx);
}

/*! \brief exit current scope */
void ExitScope() { stack_.pop_back(); }

const Axis& GetAxis(int idx) const {
auto&& ret = local()->GetAxis(idx);
/*! \brief call GetAxis in top scope. */
const Axis& GetAxis(int dim) const {
auto&& ret = top()->GetAxis(dim);
return ret;
}

const inline bool MatchWithSpBlock() const { return local()->MatchWithSpBlock(); }
/*! \brief call MatchWithSpBlock in top scope. */
const inline bool MatchWithSpBlock() const { return top()->MatchWithSpBlock(); }

std::tuple<PrimExpr, PrimExpr> GetIndicesRange(int idx) { return local()->GetIndicesRange(idx); }
/*! \brief call GetIndicesRange in top scope. */
std::tuple<PrimExpr, PrimExpr> GetIndicesRange(int dim) { return top()->GetIndicesRange(dim); }

void Register(int idx, PrimExpr coordinate, PrimExpr orig_idx) { local()->Register(idx, std::move(coordinate), std::move(orig_idx)); }
/*! \brief call Register in top scope. */
void Register(int dim, PrimExpr coordinate, PrimExpr orig_idx) {
top()->Register(dim, std::move(coordinate), std::move(orig_idx));
}

private:
AxisTree tree_;
arith::Analyzer ana_;
std::vector<Scope> stack_;

inline Scope* local() const { return const_cast<Scope*>(&stack_.back()); }
/*! \brief the top scope in the sparse buffer stack. */
inline Scope* top() const { return const_cast<Scope*>(&stack_.back()); }
};

/*!
Expand All @@ -361,21 +399,26 @@ class SparseBufferCtx {
class IndexTransformer : public StmtExprMutator {
public:
explicit IndexTransformer(const AxisTree& axis_tree)
: axis_tree_(axis_tree), sp_blk_ctx_(axis_tree), sp_buf_ctx_(axis_tree) {}
: sp_blk_ctx_(axis_tree), sp_buf_ctx_(axis_tree), axis_tree_(axis_tree) {}

private:
// Sparse block context stack;
SparseBlockCtx sp_blk_ctx_;
// Sparse buffer context stack;
SparseBufferCtx sp_buf_ctx_;

PrimExpr ViewIndexInAxis(int idx, PrimExpr index) {
/*!
* \brief Return the offset of index on given dimension.
* \param dim The dimension index.
* \param index The PrimExpr representing the index on this dimension.
*/
PrimExpr ViewIndexInAxis(int dim, PrimExpr index) {
// decompress index to coordinate on iterator axis.
// the index might not be a single var node, use visitor to recursive construct the coordinate.
PrimExpr coordinate = ExprMutator::VisitExpr(index);
const Axis& axis = sp_buf_ctx_.GetAxis(idx);
const Axis& axis = sp_buf_ctx_.GetAxis(dim);
// register to sparse buffer scope
sp_buf_ctx_.Register(idx, coordinate, index);
sp_buf_ctx_.Register(dim, coordinate, index);

PrimExpr offset = index;
// compress coordinate to index on sparse buffer axis.
Expand All @@ -388,14 +431,14 @@ class IndexTransformer : public StmtExprMutator {
case AxisKind::kSparseFixed: {
auto sf_axis = axis.as<SparseFixedAxisNode>();
PrimExpr l, r;
std::tie(l, r) = sp_buf_ctx_.GetIndicesRange(idx);
std::tie(l, r) = sp_buf_ctx_.GetIndicesRange(dim);
offset = lower_bound(sf_axis->indices->data, coordinate, l, r);
break;
}
case AxisKind::kSparseVariable:
auto sv_axis = axis.as<SparseVariableAxisNode>();
PrimExpr l, r;
std::tie(l, r) = sp_buf_ctx_.GetIndicesRange(idx);
std::tie(l, r) = sp_buf_ctx_.GetIndicesRange(dim);
offset = lower_bound(sv_axis->indices->data, coordinate, l, r);
break;
}
Expand All @@ -404,6 +447,11 @@ class IndexTransformer : public StmtExprMutator {
return offset;
}

/*!
* \brief Compute the offset of given indices in compressed sparse buffer layout.
* \param sp_buffer The sparse buffer to access.
* \param indices The array of indices.
*/
PrimExpr ComputeOffset(SparseBuffer sp_buffer, const Array<PrimExpr>& indices) {
int num_lowered_indices = static_cast<int>(indices.size());
ICHECK_LE(num_lowered_indices, sp_buffer->ndim());
Expand All @@ -426,7 +474,7 @@ class IndexTransformer : public StmtExprMutator {
auto it = sp_blk_ctx_.GetSparseIterVar(v);
if (it.defined()) {
return sp_blk_ctx_.GetCoordinate(it.value().get());
} else{
} else {
return GetRef<PrimExpr>(v);
}
}
Expand Down

0 comments on commit 614fb8a

Please sign in to comment.