diff --git a/src/meta_schedule/database/database_utils.cc b/src/meta_schedule/database/database_utils.cc index 389c69fe9c8b..fcbe8276b98e 100644 --- a/src/meta_schedule/database/database_utils.cc +++ b/src/meta_schedule/database/database_utils.cc @@ -16,6 +16,7 @@ * specific language governing permissions and limitations * under the License. */ +#include #include #include #include @@ -40,7 +41,12 @@ void JSONDumps(ObjectRef json_obj, std::ostringstream& os) { os << int_imm->value; } } else if (const auto* float_imm = json_obj.as()) { - os << std::setprecision(20) << float_imm->value; + double int_part; + if (std::modf(float_imm->value, &int_part) == 0.0) { + os << std::fixed << std::setprecision(1) << float_imm->value; + } else { + os << std::setprecision(20) << float_imm->value; + } } else if (const auto* str = json_obj.as()) { os << '"' << support::StrEscape(str->data, str->size) << '"'; } else if (const auto* array = json_obj.as()) { diff --git a/tests/python/unittest/test_tir_schedule_trace.py b/tests/python/unittest/test_tir_schedule_trace.py index 916db184e09b..6b86240724f8 100644 --- a/tests/python/unittest/test_tir_schedule_trace.py +++ b/tests/python/unittest/test_tir_schedule_trace.py @@ -244,6 +244,40 @@ def test_trace_as_json_1(): ] +def test_trace_as_json_floatimm(): + var = tir.Var("v", "int32") + trace1 = Trace( + insts=[ + Instruction( + kind=InstructionKind.get("SampleCategorical"), + inputs=[], + attrs=[[tvm.tir.IntImm("int32", 3)], [tvm.tir.FloatImm("float32", 1.0)]], + outputs=[var], + ) + ], + decisions={}, + ) + json1 = trace1.as_json() + assert json1 == [[["SampleCategorical", [], [[3], [1.0]], ["v0"]]], []] + + trace2 = Trace( + insts=[ + Instruction( + kind=InstructionKind.get("SampleCategorical"), + inputs=[], + attrs=[ + [tvm.tir.IntImm("int32", 3), tvm.tir.IntImm("int32", 4)], + [tvm.tir.FloatImm("float32", 0.5), tvm.tir.FloatImm("float32", 0.5)], + ], + outputs=[var], + ) + ], + decisions={}, + ) + json2 = trace2.as_json() + assert json2 == [[["SampleCategorical", [], [[3, 4], [0.5, 0.5]], ["v0"]]], []] + + def test_trace_simplified_1(): trace = _make_trace_3(BlockRV(), BlockRV(), add_postproc=True) assert str(trace) == "\n".join( @@ -367,5 +401,4 @@ def test_apply_annotation_from_json(): if __name__ == "__main__": - test_trace_simplified_2() - # tvm.testing.main() + tvm.testing.main()