diff --git a/src/meta_schedule/mutator/mutate_tile_size.cc b/src/meta_schedule/mutator/mutate_tile_size.cc index 8fb83147ea7b..ea4e81c16f0c 100644 --- a/src/meta_schedule/mutator/mutate_tile_size.cc +++ b/src/meta_schedule/mutator/mutate_tile_size.cc @@ -128,6 +128,13 @@ void FindSampleVectorize(const Trace& trace, std::vector* inst, if (inst->kind.same_as(inst_sample_categorical)) { ICHECK_EQ(inst->outputs.size(), 1); if (annotated.count(inst->outputs[0].get())) { + ICHECK_EQ(inst->attrs.size(), 2); + std::vector probs = + support::AsVector(Downcast>(inst->attrs[1])); + if (probs.size() == 1) { + // Skip mutating the sampling instructions who have only single candidate. + continue; + } const auto* d = TVM_TYPE_AS(decision, IntImmNode); instructions.push_back(inst); decisions.push_back(d->value); diff --git a/tests/python/unittest/test_meta_schedule_mutator_mutate_tile_size.py b/tests/python/unittest/test_meta_schedule_mutator_mutate_tile_size.py index 0600c0b79194..c09ef3e87066 100644 --- a/tests/python/unittest/test_meta_schedule_mutator_mutate_tile_size.py +++ b/tests/python/unittest/test_meta_schedule_mutator_mutate_tile_size.py @@ -95,5 +95,18 @@ def test_mutate_tile_size_matmul(): assert len(results) > 15 +def test_mutate_sample_categorical_single_candidate(): + mutator = _make_mutator( + target=Target("llvm --num-cores=16"), + ) + sch = Schedule(matmul, debug_mask="all") + sch.sample_categorical(candidates=[1], probs=[1.0], decision=0) + + # The mutator finds the SampleCategorical has only one candidate, and thus skips it. + trace = mutator.apply(sch.trace) + assert trace is None + + if __name__ == "__main__": test_mutate_tile_size_matmul() + test_mutate_sample_categorical_single_candidate()