Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix wrong result of cast(float as decimal) when overflow happens (#4380) #4389

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion dbms/src/Functions/FunctionsTiDBConversion.h
Original file line number Diff line number Diff line change
Expand Up @@ -1233,7 +1233,8 @@ struct TiDBConvertToDecimal
static_assert(std::is_floating_point_v<FromFieldType>);
/// cast real as decimal
for (size_t i = 0; i < size; ++i)
vec_to[i] = toTiDBDecimal<FromFieldType, ToFieldType>(vec_from[i], prec, scale, context);
// Always use Float64 to avoid overflow for vec_from[i] * 10^scale.
vec_to[i] = toTiDBDecimal<Float64, ToFieldType>(static_cast<Float64>(vec_from[i]), prec, scale, context);
}
}
else
Expand Down
97 changes: 62 additions & 35 deletions dbms/src/Functions/tests/bench_function_cast.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,8 @@ class CastToDecimalBench : public benchmark::Fixture
DataTypePtr from_type_dec_60_5 = std::make_shared<DataTypeDecimal256>(60, 5);
DataTypePtr from_type_date = std::make_shared<DataTypeMyDate>();
DataTypePtr from_type_datetime_fsp5 = std::make_shared<DataTypeMyDateTime>(5);
DataTypePtr from_type_float32 = std::make_shared<DataTypeFloat32>();
DataTypePtr from_type_float64 = std::make_shared<DataTypeFloat64>();

auto tmp_col_int8 = from_type_int8->createColumn();
auto tmp_col_int16 = from_type_int16->createColumn();
Expand All @@ -95,6 +97,8 @@ class CastToDecimalBench : public benchmark::Fixture
auto tmp_col_dec_60_5 = from_type_dec_60_5->createColumn();
auto tmp_col_date = from_type_date->createColumn();
auto tmp_col_datetime_fsp5 = from_type_date->createColumn();
auto tmp_col_float32 = ColumnFloat32::create();
auto tmp_col_float64 = ColumnFloat64::create();

std::uniform_int_distribution<int64_t> dist64(std::numeric_limits<int64_t>::min(), std::numeric_limits<int64_t>::max());

Expand All @@ -120,6 +124,8 @@ class CastToDecimalBench : public benchmark::Fixture
tmp_col_uint16->insert(Field(static_cast<Int64>(static_cast<UInt16>(dist64(mt)))));
tmp_col_uint32->insert(Field(static_cast<Int64>(static_cast<UInt16>(dist64(mt)))));
tmp_col_uint64->insert(Field(static_cast<Int64>(static_cast<UInt16>(dist64(mt)))));
tmp_col_float32->insert(static_cast<Float32>(dist64(mt)));
tmp_col_float64->insert(static_cast<Float64>(dist64(mt)));

tmp_col_dec_2_1->insert(DecimalField(Decimal(static_cast<Int32>(dist64(mt) % 100)), 1));
tmp_col_dec_2_1_small->insert(DecimalField(Decimal(static_cast<Int32>(dist64(mt) % 10)), 1));
Expand All @@ -145,6 +151,8 @@ class CastToDecimalBench : public benchmark::Fixture
from_col_uint16 = ColumnWithTypeAndName(std::move(tmp_col_uint16), from_type_uint16, "from_col_uint16");
from_col_uint32 = ColumnWithTypeAndName(std::move(tmp_col_uint32), from_type_uint32, "from_col_uint32");
from_col_uint64 = ColumnWithTypeAndName(std::move(tmp_col_uint64), from_type_uint64, "from_col_uint64");
from_col_float32 = ColumnWithTypeAndName(std::move(tmp_col_float32), from_type_float32, "from_col_float32");
from_col_float64 = ColumnWithTypeAndName(std::move(tmp_col_float64), from_type_float64, "from_col_float64");

from_col_dec_2_1 = ColumnWithTypeAndName(std::move(tmp_col_dec_2_1), from_type_dec_2_1, "from_col_dec_2_1");
from_col_dec_2_1_small = ColumnWithTypeAndName(std::move(tmp_col_dec_2_1_small), from_type_dec_2_1_small, "from_col_dec_2_1_small");
Expand Down Expand Up @@ -203,16 +211,22 @@ class CastToDecimalBench : public benchmark::Fixture
from_int64_vec = std::vector<Int64>(row_num);
from_int128_vec = std::vector<Int128>(row_num);
from_int256_vec = std::vector<Int256>(row_num);
from_float32_vec = std::vector<Float32>(row_num);
from_float64_vec = std::vector<Float64>(row_num);
dest_int64_vec = std::vector<Int64>(row_num);
dest_int128_vec = std::vector<Int128>(row_num);
dest_int256_vec = std::vector<Int256>(row_num);
dest_float32_vec = std::vector<Float32>(row_num);
dest_float64_vec = std::vector<Float64>(row_num);
const Int256 mod_prec_19 = getScaleMultiplier<Decimal256>(19);
const Int256 mod_prec_38 = getScaleMultiplier<Decimal256>(38);
for (auto i = 0; i < row_num; ++i)
{
from_int64_vec[i] = dist64(mt);
from_int128_vec[i] = static_cast<Int128>(dist256(mt) % (std::numeric_limits<Int128>::max() % mod_prec_19));
from_int256_vec[i] = static_cast<Int256>(dist256(mt) % (std::numeric_limits<Int256>::max()) % mod_prec_38);
from_float32_vec[i] = static_cast<Float32>(from_int64_vec[i]);
from_float64_vec[i] = static_cast<Float64>(from_int64_vec[i]);
}
}

Expand All @@ -227,6 +241,8 @@ class CastToDecimalBench : public benchmark::Fixture
ColumnWithTypeAndName from_col_uint16;
ColumnWithTypeAndName from_col_uint32;
ColumnWithTypeAndName from_col_uint64;
ColumnWithTypeAndName from_col_float32;
ColumnWithTypeAndName from_col_float64;
ColumnWithTypeAndName from_col_dec_2_1;
ColumnWithTypeAndName from_col_dec_2_1_small;
ColumnWithTypeAndName from_col_dec_3_0;
Expand Down Expand Up @@ -267,9 +283,13 @@ class CastToDecimalBench : public benchmark::Fixture
std::vector<Int64> from_int64_vec;
std::vector<Int128> from_int128_vec;
std::vector<Int256> from_int256_vec;
std::vector<Float32> from_float32_vec;
std::vector<Float64> from_float64_vec;
std::vector<Int64> dest_int64_vec;
std::vector<Int128> dest_int128_vec;
std::vector<Int256> dest_int256_vec;
std::vector<Float32> dest_float32_vec;
std::vector<Float64> dest_float64_vec;
};

#define CAST_BENCHMARK(CLASS_NAME, CASE_NAME, FROM_COL, DEST_TYPE) \
Expand Down Expand Up @@ -334,6 +354,9 @@ CAST_BENCHMARK(CastToDecimalBench, int32_to_decimal_60_0, from_col_int32, dest_c
// no; Int64; Int256
CAST_BENCHMARK(CastToDecimalBench, int32_to_decimal_60_4, from_col_int32, dest_col_dec_60_4);

CAST_BENCHMARK(CastToDecimalBench, float32_to_decimal_60_30, from_col_float32, dest_col_dec_60_30);
CAST_BENCHMARK(CastToDecimalBench, float64_to_decimal_60_30, from_col_float64, dest_col_dec_60_30);

// need; Int128; Int32
CAST_BENCHMARK(CastToDecimalBench, int64_to_decimal_8_0, from_col_int64, dest_col_dec_8_0);
// need; Int128; Int64
Expand Down Expand Up @@ -410,44 +433,48 @@ STATIC_CAST_BENCHMARK(CastToDecimalBench, 64, 256);
STATIC_CAST_BENCHMARK(CastToDecimalBench, 128, 128);
STATIC_CAST_BENCHMARK(CastToDecimalBench, 128, 256);

#define DIV_BENCHMARK(CLASS_NAME, TYPE) \
BENCHMARK_DEFINE_F(CastToDecimalBench, div_##TYPE) \
(benchmark::State & state) \
try \
{ \
for (auto _ : state) \
{ \
for (int i = 0; i < row_num; ++i) \
{ \
dest_int##TYPE##_vec[i] = from_int##TYPE##_vec[i] / from_int##TYPE##_vec[0]; \
} \
} \
} \
CATCH \
#define DIV_BENCHMARK(CLASS_NAME, TYPE) \
BENCHMARK_DEFINE_F(CastToDecimalBench, div_##TYPE) \
(benchmark::State & state) \
try \
{ \
for (auto _ : state) \
{ \
for (int i = 0; i < row_num; ++i) \
{ \
dest_##TYPE##_vec[i] = from_##TYPE##_vec[i] / from_##TYPE##_vec[0]; \
} \
} \
} \
CATCH \
BENCHMARK_REGISTER_F(CastToDecimalBench, div_##TYPE)->Iterations(1000);

DIV_BENCHMARK(CastToDecimalBench, 64);
DIV_BENCHMARK(CastToDecimalBench, 128);
DIV_BENCHMARK(CastToDecimalBench, 256);

#define MUL_BENCHMARK(CLASS_NAME, TYPE) \
BENCHMARK_DEFINE_F(CastToDecimalBench, mul_##TYPE) \
(benchmark::State & state) \
try \
{ \
for (auto _ : state) \
{ \
for (int i = 0; i < row_num; ++i) \
{ \
dest_int##TYPE##_vec[i] = from_int##TYPE##_vec[i] * from_int##TYPE##_vec[0]; \
} \
} \
} \
CATCH \
DIV_BENCHMARK(CastToDecimalBench, int64);
DIV_BENCHMARK(CastToDecimalBench, int128);
DIV_BENCHMARK(CastToDecimalBench, int256);
DIV_BENCHMARK(CastToDecimalBench, float32);
DIV_BENCHMARK(CastToDecimalBench, float64);

#define MUL_BENCHMARK(CLASS_NAME, TYPE) \
BENCHMARK_DEFINE_F(CastToDecimalBench, mul_##TYPE) \
(benchmark::State & state) \
try \
{ \
for (auto _ : state) \
{ \
for (int i = 0; i < row_num; ++i) \
{ \
dest_##TYPE##_vec[i] = from_##TYPE##_vec[i] * from_##TYPE##_vec[0]; \
} \
} \
} \
CATCH \
BENCHMARK_REGISTER_F(CastToDecimalBench, mul_##TYPE)->Iterations(1000);

MUL_BENCHMARK(CastToDecimalBench, 64);
MUL_BENCHMARK(CastToDecimalBench, 128);
MUL_BENCHMARK(CastToDecimalBench, 256);
MUL_BENCHMARK(CastToDecimalBench, int64);
MUL_BENCHMARK(CastToDecimalBench, int128);
MUL_BENCHMARK(CastToDecimalBench, int256);
MUL_BENCHMARK(CastToDecimalBench, float32);
MUL_BENCHMARK(CastToDecimalBench, float64);
} // namespace tests
} // namespace DB
15 changes: 15 additions & 0 deletions dbms/src/Functions/tests/gtest_tidb_conversion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1164,6 +1164,21 @@ try
testNotOnlyNull<Float64, Decimal256>(12.215, DecimalField256(static_cast<Int256>(1222), 2), std::make_tuple(65, 2));
testNotOnlyNull<Float64, Decimal256>(-12.215, DecimalField256(static_cast<Int256>(-1222), 2), std::make_tuple(65, 2));

// Not compatible with MySQL/TiDB.
// MySQL/TiDB: 34028199169636080000000000000000000000.00
// TiFlash: 34028199169636079590747176440761942016.00
testNotOnlyNull<Float32, Decimal256>(3.40282e+37f, DecimalField256(Decimal256(Int256("3402819916963607959074717644076194201600")), 2), std::make_tuple(50, 2));
// MySQL/TiDB: 34028200000000000000000000000000000000.00
// TiFlash: 34028200000000004441521809130870213181.44
testNotOnlyNull<Float64, Decimal256>(3.40282e+37, DecimalField256(Decimal256(Int256("3402820000000000444152180913087021318144")), 2), std::make_tuple(50, 2));

// MySQL/TiDB: 123.12345886230469000000
// TiFlash: 123.12345886230470197248
testNotOnlyNull<Float32, Decimal256>(123.123456789123456789f, DecimalField256(Decimal256(Int256("12312345886230470197248")), 20), std::make_tuple(50, 20));
// MySQL/TiDB: 123.12345886230469000000
// TiFlash: 123.12345678912344293376
testNotOnlyNull<Float64, Decimal256>(123.123456789123456789, DecimalField256(Decimal256(Int256("12312345678912344293376")), 20), std::make_tuple(50, 20));

dag_context->setFlags(ori_flags);
dag_context->clearWarnings();
}
Expand Down