From 8b695503b55ba11d71b93a5612a31d966e38cb60 Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Thu, 4 Nov 2021 14:01:58 -0700 Subject: [PATCH] [MetaSchedule] Sample-Perfect-Tile (#501) --- include/tvm/tir/schedule/schedule.h | 10 + python/tvm/tir/schedule/schedule.py | 33 ++ src/tir/schedule/analysis/analysis.cc | 4 +- src/tir/schedule/concrete_schedule.cc | 10 + src/tir/schedule/concrete_schedule.h | 25 +- src/tir/schedule/primitive.h | 23 +- src/tir/schedule/primitive/sampling.cc | 316 +++++++++++++++++- src/tir/schedule/schedule.cc | 2 + src/tir/schedule/traced_schedule.cc | 16 + src/tir/schedule/traced_schedule.h | 11 +- src/tir/schedule/utils.h | 42 ++- src/tir/transforms/compact_buffer_region.cc | 2 +- .../unittest/test_tir_schedule_sampling.py | 43 ++- 13 files changed, 491 insertions(+), 46 deletions(-) diff --git a/include/tvm/tir/schedule/schedule.h b/include/tvm/tir/schedule/schedule.h index 2a12890b2641..fa4be5dee47d 100644 --- a/include/tvm/tir/schedule/schedule.h +++ b/include/tvm/tir/schedule/schedule.h @@ -200,6 +200,16 @@ class ScheduleNode : public runtime::Object { */ virtual ExprRV SampleCategorical(const Array& candidates, const Array& probs, Optional decision = NullOpt) = 0; + /*! + * \brief Sample the factors to perfect tile a specific loop + * \param loop_rv The loop to be tiled + * \param n The number of tiles to be sampled + * \param max_innermost_factor The maximum tile size allowed to be sampled in the innermost loop + * \param decision The sampling decision + * \return A list of length `n`, the random perfect tile sizes sampled + */ + virtual Array SamplePerfectTile(const LoopRV& loop_rv, int n, int max_innermost_factor, + Optional> decision = NullOpt) = 0; /******** Schedule: Get blocks & loops ********/ /*! diff --git a/python/tvm/tir/schedule/schedule.py b/python/tvm/tir/schedule/schedule.py index d5ee2fed6eb5..7ec8e20d41f9 100644 --- a/python/tvm/tir/schedule/schedule.py +++ b/python/tvm/tir/schedule/schedule.py @@ -325,6 +325,39 @@ def sample_categorical( decision, ) + def sample_perfect_tile( + self, + loop: LoopRV, + n: int, + max_innermost_factor: int = 16, + decision: Optional[List[int]] = None, + ) -> List[ExprRV]: + """Sample the factors to perfect tile a specific loop + + Parameters + ---------- + loop : LoopRV + The loop to be tiled + n : int + The number of tiles to be sampled + max_innermost_factor : int + The maximum tile size allowed to be sampled in the innermost loop + decision: Optional[List[int]] + The sampling decision, if any + + Returns + ------- + result : List[ExprRV] + A list of length `n`, the random perfect tile sizes sampled + """ + return _ffi_api.ScheduleSamplePerfectTile( # type: ignore # pylint: disable=no-member + self, + loop, + n, + max_innermost_factor, + decision, + ) + ########## Schedule: Get blocks & loops ########## def get_block( self, diff --git a/src/tir/schedule/analysis/analysis.cc b/src/tir/schedule/analysis/analysis.cc index d14d64a4c787..e3a535e9b3d4 100644 --- a/src/tir/schedule/analysis/analysis.cc +++ b/src/tir/schedule/analysis/analysis.cc @@ -505,8 +505,8 @@ Map LoopDomainOfSRefTreePath(const StmtSRef& low_inclusive, if (const ForNode* loop = p->StmtAs()) { if (loop->kind == ForKind::kThreadBinding) { const String& thread_tag = loop->thread_binding.value()->thread_tag; - if (CanRelaxStorageUndereThread(extra_relax_scope, - runtime::ThreadScope::Create(thread_tag))) { + if (CanRelaxStorageUnderThread(extra_relax_scope, + runtime::ThreadScope::Create(thread_tag))) { result.Set(loop->loop_var, Range::FromMinExtent(loop->min, loop->extent)); } } diff --git a/src/tir/schedule/concrete_schedule.cc b/src/tir/schedule/concrete_schedule.cc index 944d7c7ef111..54760abbe521 100644 --- a/src/tir/schedule/concrete_schedule.cc +++ b/src/tir/schedule/concrete_schedule.cc @@ -232,6 +232,16 @@ ExprRV ConcreteScheduleNode::SampleCategorical(const Array& candidates, throw; } +Array ConcreteScheduleNode::SamplePerfectTile(const LoopRV& loop_rv, int n, + int max_innermost_factor, + Optional> decision) { + TVM_TIR_SCHEDULE_BEGIN(); + return CreateRV(tir::SamplePerfectTile(&this->rand_state_, this->GetSRef(loop_rv), n, + max_innermost_factor, &decision)); + TVM_TIR_SCHEDULE_END("sample-perfect-tile", this->error_render_level_); + throw; +} + /******** Schedule: Get blocks & loops ********/ BlockRV ConcreteScheduleNode::GetBlock(const String& name, const String& func_name) { diff --git a/src/tir/schedule/concrete_schedule.h b/src/tir/schedule/concrete_schedule.h index 56af2e8768dd..d053e3329fce 100644 --- a/src/tir/schedule/concrete_schedule.h +++ b/src/tir/schedule/concrete_schedule.h @@ -81,16 +81,10 @@ class ConcreteScheduleNode : public ScheduleNode { public: /******** Schedule: Sampling ********/ - /*! - * \brief Sample an integer given the probability distribution - * \param candidates The candidates - * \param probs The probability distribution of the candidates - * \param decision The sampling decision, if it's given we would validate the decision, otherwise - * we would sample a decision from the distribution and set the decision accordingly. - * \return The random variable sampled from candidates - */ ExprRV SampleCategorical(const Array& candidates, const Array& probs, Optional decision = NullOpt) override; + Array SamplePerfectTile(const LoopRV& loop_rv, int n, int max_innermost_factor, + Optional> decision = NullOpt) override; /******** Schedule: Get blocks & loops ********/ BlockRV GetBlock(const String& name, const String& func_name = "main") override; Array GetLoops(const BlockRV& block_rv) override; @@ -162,6 +156,12 @@ class ConcreteScheduleNode : public ScheduleNode { * \return The new random variable created */ inline ExprRV CreateRV(int64_t value); + /*! + * \brief Add a list of integers as random variables into the symbol table + * \param value The list of integers to be added to the symbol table + * \return The new random variables created + */ + inline Array CreateRV(const std::vector& value); /*! \brief Remove a random variable from the symbol table */ inline void RemoveFromSymbolTable(const ObjectRef& rv); }; @@ -295,6 +295,15 @@ inline ExprRV ConcreteScheduleNode::CreateRV(int64_t value) { return std::move(rv); } +inline Array ConcreteScheduleNode::CreateRV(const std::vector& value) { + Array results; + results.reserve(value.size()); + for (int64_t v : value) { + results.push_back(CreateRV(v)); + } + return results; +} + inline void ConcreteScheduleNode::RemoveFromSymbolTable(const ObjectRef& obj) { auto it = this->symbol_table_.find(obj); if (it != this->symbol_table_.end()) { diff --git a/src/tir/schedule/primitive.h b/src/tir/schedule/primitive.h index 8fe3018d2c82..aa726c48b024 100644 --- a/src/tir/schedule/primitive.h +++ b/src/tir/schedule/primitive.h @@ -32,8 +32,8 @@ namespace tir { * \param max_exclusive The maximum value of the range, exclusive. * \return The random integer sampled in the given range. */ -TVM_DLL int SampleInt(support::LinearCongruentialEngine::TRandState* rand_state, int min_inclusive, - int max_exclusive); +TVM_DLL int32_t SampleInt(support::LinearCongruentialEngine::TRandState* rand_state, + int32_t min_inclusive, int32_t max_exclusive); /*! * \brief Sample once category from candidates according to the probability weights. * \param self The schedule to update @@ -46,6 +46,25 @@ TVM_DLL int SampleInt(support::LinearCongruentialEngine::TRandState* rand_state, TVM_DLL int64_t SampleCategorical(support::LinearCongruentialEngine::TRandState* rand_state, const Array& candidates, const Array& probs, Optional* decision); +/*! + * \brief Sample the factors to perfect tile a specific loop + * \param rand_state The random state + * \param loop_sref The loop to be tiled + * \param n The number of tiles to be sampled + * \param max_innermost_factor The maximum tile size allowed to be sampled in the innermost loop + * \param decision The sampling decision + * \return A list of length `n`, the random perfect tile sizes sampled + */ +TVM_DLL std::vector SamplePerfectTile( + support::LinearCongruentialEngine::TRandState* rand_state, // + int32_t extent, int32_t n_splits); +TVM_DLL std::vector SamplePerfectTile( + support::LinearCongruentialEngine::TRandState* rand_state, // + int32_t extent, int32_t n_split, int32_t max_innermost_factor); +TVM_DLL std::vector SamplePerfectTile( + support::LinearCongruentialEngine::TRandState* rand_state, // + const tir::StmtSRef& loop_sref, int32_t n_split, int32_t max_innermost_factor, + Optional>* decision); /******** Schedule: Get blocks & loops ********/ /*! diff --git a/src/tir/schedule/primitive/sampling.cc b/src/tir/schedule/primitive/sampling.cc index 6ac6226118cd..4acf61860112 100644 --- a/src/tir/schedule/primitive/sampling.cc +++ b/src/tir/schedule/primitive/sampling.cc @@ -20,30 +20,156 @@ #include #include "../utils.h" +#include "tvm/support/random_engine.h" namespace tvm { namespace tir { -int SampleInt(support::LinearCongruentialEngine::TRandState* rand_state, int min_inclusive, - int max_exclusive) { +struct PrimeTable { + /*! \brief The table contains prime numbers in [2, kMaxPrime) */ + static constexpr const int32_t kMaxPrime = 65536; + /*! \brief The exact number of prime numbers in the table */ + static constexpr const int32_t kNumPrimes = 6542; + /*! + * \brief For each number in [2, kMaxPrime), the index of its min factor. + * For example, if min_factor_idx[x] = i, then the min factor of x is primes[i]. + */ + int32_t min_factor_idx[kMaxPrime]; + /*! \brief The prime numbers in [2, kMaxPrime) */ + std::vector primes; + /*! + * \brief The power of each prime number. + * pow_table[i, j] stores the result of pow(prime[i], j + 1) + */ + std::vector> pow_tab; + + /*! \brief Get a global instance of the prime table */ + static const PrimeTable* Global() { + static const PrimeTable table; + return &table; + } + + /*! \brief Constructor, pre-computes all info in the prime table */ + PrimeTable() { + constexpr const int64_t int_max = std::numeric_limits::max(); + // Euler's sieve: prime number in linear time + for (int32_t i = 0; i < kMaxPrime; ++i) { + min_factor_idx[i] = -1; + } + primes.reserve(kNumPrimes); + for (int32_t x = 2; x < kMaxPrime; ++x) { + if (min_factor_idx[x] == -1) { + min_factor_idx[x] = primes.size(); + primes.push_back(x); + } + for (size_t i = 0; i < primes.size(); ++i) { + int64_t factor = primes[i]; + int64_t y = x * factor; + if (y >= kMaxPrime) { + break; + } + min_factor_idx[y] = i; + if (x % factor == 0) { + break; + } + } + } + ICHECK_EQ(static_cast(primes.size()), static_cast(kNumPrimes)); + // Calculate the power table for each prime number + pow_tab.reserve(primes.size()); + for (int32_t prime : primes) { + std::vector tab; + tab.reserve(32); + for (int64_t pow = prime; pow <= int_max; pow *= prime) { + tab.push_back(pow); + } + tab.shrink_to_fit(); + pow_tab.emplace_back(std::move(tab)); + } + } + /*! + * \brief Factorize a number n, and return in a cryptic format + * \param n The number to be factorized + * \return A list of integer pairs [(i_1, j_1), (i_2, j_2), ..., (i_l, j_l)] + * For each pair (i, j), we define + * (a, b) = (j, 1) if i == -1 (in this case j must be a prime number) + * (primes[i], j) if i != -1 + * Then the factorization is + * n = (a_1 ^ b_1) * (a_2 ^ b_2) ... (a_l ^ b_l) + */ + std::vector> Factorize(int32_t n) const { + std::vector> result; + result.reserve(16); + int32_t i = 0, n_primes = primes.size(); + // Phase 1: n >= kMaxPrime + for (int32_t j; n >= kMaxPrime && i < n_primes && primes[i] * primes[i] <= n; ++i) { + for (j = 0; n % primes[i] == 0; n /= primes[i], ++j) { + } + if (j != 0) { + result.emplace_back(i, j); + } + } + // if i >= n_primes or primes[i] > sqrt(n), then n must be a prime number + if (n >= kMaxPrime) { + result.emplace_back(-1, n); + return result; + } + // Phase 2: n < kMaxPrime + for (int32_t j; n > 1;) { + int32_t i = min_factor_idx[n]; + for (j = 0; n % primes[i] == 0; n /= primes[i], ++j) { + } + result.emplace_back(i, j); + } + return result; + } +}; + +int32_t SampleInt(support::LinearCongruentialEngine::TRandState* rand_state, int32_t min_inclusive, + int32_t max_exclusive) { CHECK(min_inclusive < max_exclusive) << "ValueError: max_exclusive must be greater than min_inclusive."; if (min_inclusive + 1 == max_exclusive) { return min_inclusive; } support::LinearCongruentialEngine rand_(rand_state); - std::uniform_int_distribution dist(min_inclusive, max_exclusive - 1); + std::uniform_int_distribution dist(min_inclusive, max_exclusive - 1); return dist(rand_); } +std::vector SampleWithoutReplacement( + support::LinearCongruentialEngine::TRandState* rand_state, int32_t n, int32_t k) { + if (k == 1) { + return {SampleInt(rand_state, 0, n)}; + } + if (k == 2) { + int32_t result0 = SampleInt(rand_state, 0, n); + int32_t result1 = SampleInt(rand_state, 0, n - 1); + if (result1 >= result0) { + result1 += 1; + } + return {result0, result1}; + } + std::vector order(n); + for (int32_t i = 0; i < n; ++i) { + order[i] = i; + } + for (int32_t i = 0; i < k; ++i) { + int32_t j = SampleInt(rand_state, i, n); + if (i != j) { + std::swap(order[i], order[j]); + } + } + return {order.begin(), order.begin() + k}; +} + int64_t SampleCategorical(support::LinearCongruentialEngine::TRandState* rand_state, const Array& candidates, const Array& probs, Optional* decision) { CHECK(candidates.size() == probs.size()) << "ValueError: number of candidates does not match number of probabilities."; - int i = -1; - int n = candidates.size(); - + int32_t i = -1; + int32_t n = candidates.size(); if (decision->defined()) { const auto* int_imm = decision->as(); i = int_imm->value; @@ -51,7 +177,7 @@ int64_t SampleCategorical(support::LinearCongruentialEngine::TRandState* rand_st << ", but decision is: " << i; } else { std::vector weights = support::AsVector(probs); - std::discrete_distribution dist(weights.begin(), weights.end()); + std::discrete_distribution dist(weights.begin(), weights.end()); support::LinearCongruentialEngine rand_(rand_state); i = dist(rand_); ICHECK(0 <= i && i < n) << "ValueError: Unexpected decision generated, where n = " << n @@ -62,6 +188,151 @@ int64_t SampleCategorical(support::LinearCongruentialEngine::TRandState* rand_st return candidates[i]; } +std::vector SamplePerfectTile(support::LinearCongruentialEngine::TRandState* rand_state, + int32_t extent, int32_t n_splits) { + CHECK_GE(extent, 1) << "ValueError: Cannot tile a loop with 0 or negative extent"; + CHECK_GE(n_splits, 1) << "ValueError: Cannot tile a loop to 0 or negative splits"; + // Handle special case that we can potentially accelerate + if (n_splits == 1) { + return {extent}; + } + if (extent == 1) { + return std::vector(n_splits, 1); + } + // Enumerate each pair (i, j), we define + // (a, p) = (j, 1) if i == -1 (in this case j must be a prime number) + // (primes[i], j) if i != -1 + // Then the factorization is + // extent = (a_1 ^ p_1) * (a_2 ^ p_2) ... (a_l ^ p_l) + const PrimeTable* prime_tab = PrimeTable::Global(); + std::vector> factorized = prime_tab->Factorize(extent); + if (n_splits == 2) { + // n_splits = 2, this can be taken special care of, + // because general reservoir sampling can be avoided to accelerate the sampling + int32_t result0 = 1; + int32_t result1 = 1; + for (const std::pair& ij : factorized) { + // Case 1: (a, p) = (j, 1), where j is a prime number + if (ij.first == -1) { + (SampleInt(rand_state, 0, 2) ? result1 : result0) *= ij.second; + continue; + } + // Case 2: (a = primes[i], p = 1) + int32_t p = ij.second; + const int32_t* pow = prime_tab->pow_tab[ij.first].data() - 1; + int32_t x1 = SampleInt(rand_state, 0, p + 1); + int32_t x2 = p - x1; + if (x1 != 0) { + result0 *= pow[x1]; + } + if (x2 != 0) { + result1 *= pow[x2]; + } + } + return {result0, result1}; + } + // Data range: + // 2 <= extent <= 2^31 - 1 + // 3 <= n_splits <= max tiling splits + // 1 <= p <= 31 + std::vector result(n_splits, 1); + for (const std::pair& ij : factorized) { + // Handle special cases to accelerate sampling + // Case 1: (a, p) = (j, 1), where j is a prime number + if (ij.first == -1) { + result[SampleInt(rand_state, 0, n_splits)] *= ij.second; + continue; + } + // Case 2: (a = primes[i], p = 1) + int32_t p = ij.second; + if (p == 1) { + result[SampleInt(rand_state, 0, n_splits)] *= prime_tab->primes[ij.first]; + continue; + } + // The general case. We have to sample uniformly from the solution of: + // x_1 + x_2 + ... + x_{n_splits} = p + // where x_i >= 0 + // Data range: + // 2 <= p <= 31 + // 3 <= n_splits <= max tiling splits + std::vector sampled = + SampleWithoutReplacement(rand_state, p + n_splits - 1, n_splits - 1); + std::sort(sampled.begin(), sampled.end()); + sampled.push_back(p + n_splits - 1); + const int32_t* pow = prime_tab->pow_tab[ij.first].data() - 1; + for (int32_t i = 0, last = -1; i < n_splits; ++i) { + int32_t x = sampled[i] - last - 1; + last = sampled[i]; + if (x != 0) { + result[i] *= pow[x]; + } + } + } + return result; +} + +std::vector SamplePerfectTile(support::LinearCongruentialEngine::TRandState* rand_state, + int32_t extent, int32_t n_splits, + int32_t max_innermost_factor) { + if (max_innermost_factor == -1) { + return SamplePerfectTile(rand_state, extent, n_splits); + } + CHECK_GE(n_splits, 2) << "ValueError: Cannot tile a loop into " << n_splits << " splits"; + std::vector innermost_candidates; + innermost_candidates.reserve(max_innermost_factor); + for (int32_t i = 1; i <= max_innermost_factor; ++i) { + if (extent % i == 0) { + innermost_candidates.push_back(i); + } + } + // N.B. Theoretically sampling evenly breaks the uniform sampling of the global sampling space. + // We should do multiple factorization to weight the choices. However, it would lead to slower + // sampling speed. On the other hand, considering potential tricks we might do on the innermost + // loop, in which sampling uniformly does not help, let's leave it as it is for now, and maybe add + // more heuristics in the future + int32_t innermost = innermost_candidates[SampleInt(rand_state, 0, innermost_candidates.size())]; + std::vector result = SamplePerfectTile(rand_state, extent / innermost, n_splits - 1); + result.push_back(innermost); + return result; +} + +std::vector SamplePerfectTile( + support::LinearCongruentialEngine::TRandState* rand_state, // + const tir::StmtSRef& loop_sref, int32_t n_splits, int32_t max_innermost_factor, + Optional>* decision) { + const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref); + int64_t extent = GetLoopIntExtent(loop); + std::vector result; + if (extent == -1) { + // Case 1. Handle loops with non-constant length + result = std::vector(n_splits, 1); + result[0] = -1; + } else if (decision->defined()) { + // Case 2. Use previous decision + result = support::AsVector(decision->value()); + int n = result.size(); + ICHECK_GE(n, 2); + int64_t len = extent; + for (int i = n - 1; i > 0; --i) { + int64_t& l = result[i]; + // A previous decision could become invalid because of the change of outer tiles + // To handle this case properly, we check if the tiling strategy is still perfect. + // If not, we use a trivial default solution (1, 1, ..., 1, L) for rest of the tiles + if (len % l != 0) { + l = len; + } + len /= l; + } + result[0] = len; + } else { + // Case 3. Use fresh new sampling result + result = SamplePerfectTile(rand_state, extent, n_splits, max_innermost_factor); + ICHECK_LE(result.back(), max_innermost_factor); + } + *decision = support::AsArray(result); + return result; +} + /******** InstructionKind Registration ********/ struct SampleCategoricalTraits : public UnpackedInstTraits { @@ -96,7 +367,38 @@ struct SampleCategoricalTraits : public UnpackedInstTraits { + static constexpr const char* kName = "SamplePerfectTile"; + static constexpr bool kIsPure = true; + + private: + static constexpr size_t kNumInputs = 1; + static constexpr size_t kNumAttrs = 2; + static constexpr size_t kNumDecisions = 1; + + static Array UnpackedApplyToSchedule(Schedule sch, LoopRV loop_rv, Integer n, + Integer max_innermost_factor, + Optional> decision) { + return sch->SamplePerfectTile(loop_rv, n->value, max_innermost_factor->value, decision); + } + + static String UnpackedAsPython(Array outputs, String loop_rv, Integer n, + Integer max_innermost_factor, Optional> decision) { + PythonAPICall py("sample_perfect_tile"); + py.Input("loop", loop_rv); + py.Input("n", n->value); + py.Input("max_innermost_factor", max_innermost_factor->value); + py.Decision(decision); + py.OutputList(outputs); + return py.Str(); + } + + template + friend struct ::tvm::tir::UnpackedInstTraits; +}; + TVM_REGISTER_INST_KIND_TRAITS(SampleCategoricalTraits); +TVM_REGISTER_INST_KIND_TRAITS(SamplePerfectTileTraits); } // namespace tir } // namespace tvm diff --git a/src/tir/schedule/schedule.cc b/src/tir/schedule/schedule.cc index 12287a9d7433..8f7caa914530 100644 --- a/src/tir/schedule/schedule.cc +++ b/src/tir/schedule/schedule.cc @@ -123,6 +123,8 @@ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleRemoveRV") /******** (FFI) Sampling ********/ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleSampleCategorical") .set_body_method(&ScheduleNode::SampleCategorical); +TVM_REGISTER_GLOBAL("tir.schedule.ScheduleSamplePerfectTile") + .set_body_method(&ScheduleNode::SamplePerfectTile); /******** (FFI) Get blocks & loops ********/ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleGetBlock") .set_body_method(&ScheduleNode::GetBlock); diff --git a/src/tir/schedule/traced_schedule.cc b/src/tir/schedule/traced_schedule.cc index 5c86d412cc50..94f15d5c6543 100644 --- a/src/tir/schedule/traced_schedule.cc +++ b/src/tir/schedule/traced_schedule.cc @@ -43,6 +43,7 @@ Schedule TracedScheduleNode::Copy() const { } /******** Schedule: Sampling ********/ + ExprRV TracedScheduleNode::SampleCategorical(const Array& candidates, const Array& probs, Optional decision) { @@ -57,6 +58,21 @@ ExprRV TracedScheduleNode::SampleCategorical(const Array& candidates, return result; } +Array TracedScheduleNode::SamplePerfectTile(const LoopRV& loop_rv, int n, + int max_innermost_factor, + Optional> decision) { + Array results = CreateRV(tir::SamplePerfectTile( + &this->rand_state_, this->GetSRef(loop_rv), n, max_innermost_factor, &decision)); + + static const InstructionKind& kind = InstructionKind::Get("SamplePerfectTile"); + trace_->Append(/*inst=*/Instruction(/*kind=*/kind, // + /*inputs=*/{loop_rv}, + /*attrs=*/{Integer(n), Integer(max_innermost_factor)}, + /*outputs=*/{results.begin(), results.end()}), + /*decision=*/decision); + return results; +} + /******** Schedule: Get blocks & loops ********/ BlockRV TracedScheduleNode::GetBlock(const String& name, const String& func_name) { diff --git a/src/tir/schedule/traced_schedule.h b/src/tir/schedule/traced_schedule.h index c204312634f5..d5676f4cdce7 100644 --- a/src/tir/schedule/traced_schedule.h +++ b/src/tir/schedule/traced_schedule.h @@ -47,17 +47,10 @@ class TracedScheduleNode : public ConcreteScheduleNode { public: /******** Schedule: Sampling ********/ - /*! - * \brief Sample an integer given the probability distribution - * \param candidates The candidates - * \param probs The probability distribution of the candidates - * \param decision The sampling decision, if it's given we would validate the decision, otherwise - * we would sample a decision from the distribution and set the decision accordingly. - * \return The random variable sampled from candidates - */ ExprRV SampleCategorical(const Array& candidates, const Array& probs, Optional decision = NullOpt) final; - + Array SamplePerfectTile(const LoopRV& loop_rv, int n, int max_innermost_factor, + Optional> decision = NullOpt) final; /******** Schedule: Get blocks & loops ********/ BlockRV GetBlock(const String& name, const String& func_name = "main") final; Array GetLoops(const BlockRV& block_rv) final; diff --git a/src/tir/schedule/utils.h b/src/tir/schedule/utils.h index a63a9f079617..c66c2ca76693 100644 --- a/src/tir/schedule/utils.h +++ b/src/tir/schedule/utils.h @@ -53,8 +53,8 @@ namespace tir { * \brief A helper macro to convert an sref to the statement it points to, * then check if the downcasting succeeded. * \param Result The result variable, used for checking - * \param SRef The SRef to be casted - * \param Type The type to be casted to, can be Block or For + * \param SRef The SRef to be cast + * \param Type The type to be cast to, can be Block or For */ #define TVM_SREF_AS_OR_ERR(Result, SRef, Type) \ SRef->StmtAs(); \ @@ -64,7 +64,7 @@ namespace tir { * \brief A helper macro to convert an sref to the block it points to, * throwing an internal error if downcasting fails * \param Result The result variable, used for checking - * \param SRef The SRef to be casted + * \param SRef The SRef to be cast */ #define TVM_SREF_TO_BLOCK(Result, SRef) \ TVM_SREF_AS_OR_ERR(Result, SRef, ::tvm::tir::BlockNode) \ @@ -75,7 +75,7 @@ namespace tir { * \brief A helper macro to convert an sref to the for-loop it points to, * throwing an internal error if downcasting fails * \param Result The name of the result variable, used for checking - * \param SRef The SRef to be casted + * \param SRef The SRef to be cast */ #define TVM_SREF_TO_FOR(Result, SRef) \ TVM_SREF_AS_OR_ERR(Result, SRef, ::tvm::tir::ForNode) \ @@ -86,8 +86,8 @@ namespace tir { * \brief Downcast a TVM ObjectRef to its corresponding container using `ObjectRef::as`, * then check if the downcasting succeeded. * \param Result The result variable, used for checking - * \param From The ObjectRef to be downcasted - * \param Type The type to be downcasted to + * \param From The ObjectRef to be downcast + * \param Type The type to be downcast to */ #define TVM_TYPE_AS_OR_ERR(Result, From, Type) \ From.as(); \ @@ -97,8 +97,8 @@ namespace tir { * \brief Downcast a TVM ObjectRef to its corresponding container using `ObjectRef::as`, * throwing an internal error if downcast fails. * \param Result The result variable, used for checking - * \param From The ObjectRef to be downcasted - * \param Type The type to be downcasted to + * \param From The ObjectRef to be downcast + * \param Type The type to be downcast to */ #define TVM_TYPE_AS(Result, From, Type) \ TVM_TYPE_AS_OR_ERR(Result, From, Type) \ @@ -129,8 +129,8 @@ inline Array LoopSRefs2Loops(const Array& loop_srefs) { * \param thread_scope The thread scope to be relaxed * \return A boolean indicating the result */ -inline bool CanRelaxStorageUndereThread(const runtime::StorageScope& storage_scope, - const runtime::ThreadScope& thread_scope) { +inline bool CanRelaxStorageUnderThread(const runtime::StorageScope& storage_scope, + const runtime::ThreadScope& thread_scope) { if (storage_scope.rank == runtime::StorageRank::kWarp) { // for warp memory, we only relax threadIdx.x return thread_scope.rank == 1 && thread_scope.dim_index == 0; @@ -210,6 +210,28 @@ inline Map AsIntSet(const Map& var_dom) { return {result.begin(), result.end()}; } +/**************** Loop extents ****************/ + +/*! + * \brief Get the extents of a loop + * \param loop The loop to be queried + * \return The extents of the loop + */ +inline int64_t GetLoopIntExtent(const ForNode* loop) { + const auto* int_extent = loop->extent.as(); + return int_extent ? int_extent->value : -1; +} + +/*! + * \brief Get the extents of a loop + * \param loop_sref The loop to be queried + * \return The extents of the loop + */ +inline int64_t GetLoopIntExtent(const StmtSRef& loop_sref) { + const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref); + return GetLoopIntExtent(loop); +} + } // namespace tir } // namespace tvm diff --git a/src/tir/transforms/compact_buffer_region.cc b/src/tir/transforms/compact_buffer_region.cc index a1f488f386b3..36f0a3488cce 100644 --- a/src/tir/transforms/compact_buffer_region.cc +++ b/src/tir/transforms/compact_buffer_region.cc @@ -232,7 +232,7 @@ class BufferAccessRegionCollector : public StmtExprVisitor { const String& thread_tag = loop->thread_binding.value()->thread_tag; // When there is warp memory // threadIdx.x must be set to be warp index. - return CanRelaxStorageUndereThread(scope, runtime::ThreadScope::Create(thread_tag)); + return CanRelaxStorageUnderThread(scope, runtime::ThreadScope::Create(thread_tag)); } /**************** Class members ****************/ diff --git a/tests/python/unittest/test_tir_schedule_sampling.py b/tests/python/unittest/test_tir_schedule_sampling.py index c93c7ca63aa8..0a8b1cc13be9 100644 --- a/tests/python/unittest/test_tir_schedule_sampling.py +++ b/tests/python/unittest/test_tir_schedule_sampling.py @@ -14,15 +14,14 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -import sys from collections import defaultdict +import sys import pytest -import tvm + from tvm import tir from tvm.script import tir as T from tvm.tir.schedule.testing import verify_trace_roundtrip -from tvm.tir.schedule import Trace # pylint: disable=no-member,invalid-name,unused-variable @@ -30,9 +29,9 @@ @T.prim_func def elementwise(a: T.handle, b: T.handle) -> None: - A = T.match_buffer(a, (128, 128, 128)) - B = T.match_buffer(b, (128, 128, 128)) - with T.block([128, 128, 128], "B") as [vi, vj, vk]: + A = T.match_buffer(a, (128, 257, 1470)) + B = T.match_buffer(b, (128, 257, 1470)) + with T.block([128, 257, 1470], "B") as [vi, vj, vk]: B[vi, vj, vk] = A[vi, vj, vk] * 2.0 @@ -40,7 +39,7 @@ def elementwise(a: T.handle, b: T.handle) -> None: def test_sample_categorical(): - """Test sample categprical sampling function""" + """Test sample categorical sampling function""" n = 1000 sch = tir.Schedule(elementwise, seed=42, debug_mask="all") counter = defaultdict(int) @@ -85,5 +84,35 @@ def test_sample_categorical_serialize(): assert decisions[i] == candidates[new_sch.trace.decisions[new_inst].value] +def test_sample_perfect_tile_power_of_two(): + sch = tir.Schedule(elementwise, debug_mask="all") + i, _, _ = sch.get_loops(sch.get_block("B")) + factors = sch.sample_perfect_tile(i, n=4) + factors = [sch.get(i) for i in factors] + prod = factors[0] * factors[1] * factors[2] * factors[3] + assert prod == 128 + verify_trace_roundtrip(sch, mod=elementwise) + + +def test_sample_perfect_tile_prime(): + sch = tir.Schedule(elementwise, debug_mask="all") + _, i, _ = sch.get_loops(sch.get_block("B")) + factors = sch.sample_perfect_tile(i, n=4) + factors = [sch.get(i) for i in factors] + prod = factors[0] * factors[1] * factors[2] * factors[3] + assert prod == 257 + verify_trace_roundtrip(sch, mod=elementwise) + + +def test_sample_perfect_tile_composite(): + sch = tir.Schedule(elementwise, debug_mask="all") + _, _, i = sch.get_loops(sch.get_block("B")) + factors = sch.sample_perfect_tile(i, n=4) + factors = [sch.get(i) for i in factors] + prod = factors[0] * factors[1] * factors[2] * factors[3] + assert prod == 1470 + verify_trace_roundtrip(sch, mod=elementwise) + + if __name__ == "__main__": sys.exit(pytest.main([__file__] + sys.argv[1:]))