From ffd113afb0cf9454b88ea4ef1de464863de2a442 Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Fri, 24 Feb 2023 10:14:57 -0500 Subject: [PATCH] [FIX][MetaSchedule] JSON dump FloatImm at least one decimal In MetaSchedule, when dumping a FloatImm as JSON, at this moment we are using ```c++ os << std::setprecision(20) << float_imm->value; ``` In this way, float values that are "integers" (e.g., `1.0`, `2.0`) will be dumped as strings without decimal. For example, `1.0` will be dumped as `1`, and `2.0` will be dumped as `2`. This lead to error when we parse back the JSON string, as the parser will treat `1` as an IntImm insted of FloatImm. Therefore, this PR aims to ensure that FloatImms are printed with at least one decimal when dumping. We achieve this with the help of `std::modf` in C++ math, which extracts the integral part and fractional part from a double/float value. When the fractional part is `0.0`, we use `std::fixed` to enforce one decimal. --- src/meta_schedule/database/database_utils.cc | 8 +++- .../unittest/test_tir_schedule_trace.py | 37 ++++++++++++++++++- 2 files changed, 42 insertions(+), 3 deletions(-) diff --git a/src/meta_schedule/database/database_utils.cc b/src/meta_schedule/database/database_utils.cc index 389c69fe9c8b..fcbe8276b98e 100644 --- a/src/meta_schedule/database/database_utils.cc +++ b/src/meta_schedule/database/database_utils.cc @@ -16,6 +16,7 @@ * specific language governing permissions and limitations * under the License. */ +#include #include #include #include @@ -40,7 +41,12 @@ void JSONDumps(ObjectRef json_obj, std::ostringstream& os) { os << int_imm->value; } } else if (const auto* float_imm = json_obj.as()) { - os << std::setprecision(20) << float_imm->value; + double int_part; + if (std::modf(float_imm->value, &int_part) == 0.0) { + os << std::fixed << std::setprecision(1) << float_imm->value; + } else { + os << std::setprecision(20) << float_imm->value; + } } else if (const auto* str = json_obj.as()) { os << '"' << support::StrEscape(str->data, str->size) << '"'; } else if (const auto* array = json_obj.as()) { diff --git a/tests/python/unittest/test_tir_schedule_trace.py b/tests/python/unittest/test_tir_schedule_trace.py index 916db184e09b..6b86240724f8 100644 --- a/tests/python/unittest/test_tir_schedule_trace.py +++ b/tests/python/unittest/test_tir_schedule_trace.py @@ -244,6 +244,40 @@ def test_trace_as_json_1(): ] +def test_trace_as_json_floatimm(): + var = tir.Var("v", "int32") + trace1 = Trace( + insts=[ + Instruction( + kind=InstructionKind.get("SampleCategorical"), + inputs=[], + attrs=[[tvm.tir.IntImm("int32", 3)], [tvm.tir.FloatImm("float32", 1.0)]], + outputs=[var], + ) + ], + decisions={}, + ) + json1 = trace1.as_json() + assert json1 == [[["SampleCategorical", [], [[3], [1.0]], ["v0"]]], []] + + trace2 = Trace( + insts=[ + Instruction( + kind=InstructionKind.get("SampleCategorical"), + inputs=[], + attrs=[ + [tvm.tir.IntImm("int32", 3), tvm.tir.IntImm("int32", 4)], + [tvm.tir.FloatImm("float32", 0.5), tvm.tir.FloatImm("float32", 0.5)], + ], + outputs=[var], + ) + ], + decisions={}, + ) + json2 = trace2.as_json() + assert json2 == [[["SampleCategorical", [], [[3, 4], [0.5, 0.5]], ["v0"]]], []] + + def test_trace_simplified_1(): trace = _make_trace_3(BlockRV(), BlockRV(), add_postproc=True) assert str(trace) == "\n".join( @@ -367,5 +401,4 @@ def test_apply_annotation_from_json(): if __name__ == "__main__": - test_trace_simplified_2() - # tvm.testing.main() + tvm.testing.main()