Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix the compatibility issue in func FunctionIPv4NumToString #8210

Merged
merged 12 commits into from
Oct 19, 2023
88 changes: 59 additions & 29 deletions dbms/src/Functions/FunctionsCoding.h
Original file line number Diff line number Diff line change
Expand Up @@ -808,55 +808,85 @@ class FunctionIPv4NumToString : public IFunction

DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override
{
if (!checkDataType<DataTypeUInt32>(&*arguments[0]))
throw Exception(
fmt::format(
"Illegal type {} of argument of function {}, expected UInt32",
arguments[0]->getName(),
getName()),
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);

return std::make_shared<DataTypeString>();
if (arguments[0]->isInteger())
return makeNullable(std::make_shared<DataTypeString>());
throw Exception(
fmt::format(
"Illegal type {} of argument of function {}, expected integer",
arguments[0]->getName(),
getName()),
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
}

bool useDefaultImplementationForNulls() const override { return true; }
bool useDefaultImplementationForConstants() const override { return true; }

void executeImpl(Block & block, const ColumnNumbers & arguments, size_t result) const override
template <typename ColumnContainer>
static void executeImplColumnInteger(Block & block, const ColumnContainer & vec_in, size_t result)
{
const ColumnPtr & column = block.getByPosition(arguments[0]).column;
auto col_res = ColumnString::create();
auto nullmap_res = ColumnUInt8::create();
ColumnString::Chars_t & vec_res = col_res->getChars();
ColumnString::Offsets & offsets_res = col_res->getOffsets();
ColumnUInt8::Container & vec_res_nullmap = nullmap_res->getData();

if (const auto * col = typeid_cast<const ColumnUInt32 *>(column.get()))
{
const ColumnUInt32::Container & vec_in = col->getData();

auto col_res = ColumnString::create();
vec_res.resize(vec_in.size() * (IPV4_MAX_TEXT_LENGTH + 1)); /// the longest value is: 255.255.255.255\0
offsets_res.resize(vec_in.size());
vec_res_nullmap.assign(vec_in.size(), static_cast<UInt8>(0));

ColumnString::Chars_t & vec_res = col_res->getChars();
ColumnString::Offsets & offsets_res = col_res->getOffsets();

vec_res.resize(vec_in.size() * (IPV4_MAX_TEXT_LENGTH + 1)); /// the longest value is: 255.255.255.255\0
offsets_res.resize(vec_in.size());
char * begin = reinterpret_cast<char *>(&vec_res[0]);
char * pos = begin;
char * begin = reinterpret_cast<char *>(&vec_res[0]);
char * pos = begin;

for (size_t i = 0; i < vec_in.size(); ++i)
for (size_t i = 0; i < vec_in.size(); ++i)
{
auto && value = vec_in[i];
if (/*always `false` for unsigned integer*/ value < 0
|| /*auto optimized by compiler*/ static_cast<UInt64>(value) > std::numeric_limits<UInt32>::max())
{
formatIP<mask_tail_octets>(vec_in[i], pos);
offsets_res[i] = pos - begin;
*pos++ = 0;
vec_res_nullmap[i] = 1;
}
else
{
formatIP<mask_tail_octets>(static_cast<UInt32>(value), pos);
}
offsets_res[i] = pos - begin;
}

vec_res.resize(pos - begin);
vec_res.resize(pos - begin);
block.getByPosition(result).column = ColumnNullable::create(std::move(col_res), std::move(nullmap_res));
}

block.getByPosition(result).column = std::move(col_res);
}
void executeImpl(Block & block, const ColumnNumbers & arguments, size_t result) const override
{
const ColumnPtr & column = block.getByPosition(arguments[0]).column;

#define DISPATCH(ColType) \
else if (const auto * col = typeid_cast<const ColType *>(column.get())) \
{ \
const typename ColType::Container & vec_in = col->getData(); \
executeImplColumnInteger(block, vec_in, result); \
}

if (false) {} // NOLINT
DISPATCH(ColumnUInt64)
DISPATCH(ColumnInt64)
DISPATCH(ColumnUInt32)
DISPATCH(ColumnInt32)
DISPATCH(ColumnUInt16)
DISPATCH(ColumnInt16)
DISPATCH(ColumnUInt8)
DISPATCH(ColumnInt8)
else
{
throw Exception(
fmt::format(
"Illegal column {} of argument of function {}",
block.getByPosition(arguments[0]).column->getName(),
getName()),
ErrorCodes::ILLEGAL_COLUMN);
}
#undef DISPATCH
}
};

Expand Down
100 changes: 81 additions & 19 deletions dbms/src/Functions/tests/gtest_inet_aton_ntoa.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,6 @@
#include <TestUtils/TiFlashTestBasic.h>

#include <random>
#include <string>
#include <vector>

#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wsign-compare"
Expand Down Expand Up @@ -133,35 +131,99 @@ try
}
CATCH

TEST_F(TestInetAtonNtoa, InetNtoa)
try

template <typename Type>
static void TestInetAtonNtoaImpl(TestInetAtonNtoa & test)
{
const String func_name = "IPv4NumToString";

// empty column
ASSERT_COLUMN_EQ(
createColumn<Nullable<String>>({}),
executeFunction(func_name, createColumn<Nullable<UInt32>>({})));
test.executeFunction(func_name, createColumn<Nullable<Type>>({})));

ASSERT_COLUMN_EQ(createColumn<String>({}), executeFunction(func_name, createColumn<UInt32>({})));
ASSERT_COLUMN_EQ(createColumn<Nullable<String>>({}), test.executeFunction(func_name, createColumn<Type>({})));

// const null-only column
ASSERT_COLUMN_EQ(
createConstColumn<Nullable<String>>(1, {}),
executeFunction(func_name, createConstColumn<Nullable<UInt32>>(1, {})));
test.executeFunction(func_name, createConstColumn<Nullable<Type>>(1, {})));

// const non-null column
ASSERT_COLUMN_EQ(
createConstColumn<String>(1, "0.0.0.1"),
executeFunction(func_name, createConstColumn<Nullable<UInt32>>(1, 1)));
if constexpr (std::is_same_v<UInt8, Type>)
{
ASSERT_COLUMN_EQ(
createColumn<Nullable<String>>({"0.0.0.255"}),
test.executeFunction(func_name, createColumn<Nullable<Type>>({std::numeric_limits<Type>::max()})));
}
else if constexpr (std::is_same_v<Int8, Type>)
{
ASSERT_COLUMN_EQ(
createColumn<Nullable<String>>({{}, "0.0.0.127"}),
test.executeFunction(func_name, createColumn<Nullable<Type>>({-1, std::numeric_limits<Type>::max()})));
}
else if constexpr (std::is_same_v<UInt16, Type>)
{
ASSERT_COLUMN_EQ(
createColumn<Nullable<String>>({"0.0.255.255"}),
test.executeFunction(func_name, createColumn<Nullable<Type>>({std::numeric_limits<Type>::max()})));
}
else if constexpr (std::is_same_v<Int16, Type>)
{
ASSERT_COLUMN_EQ(
createColumn<Nullable<String>>({{}, "0.0.127.255"}),
test.executeFunction(func_name, createColumn<Nullable<Type>>({-1, std::numeric_limits<Type>::max()})));
}
else if constexpr (std::is_same_v<UInt32, Type>)
{
ASSERT_COLUMN_EQ(
createColumn<Nullable<String>>({"255.255.255.255"}),
test.executeFunction(func_name, createColumn<Nullable<Type>>({std::numeric_limits<Type>::max()})));
ASSERT_COLUMN_EQ(
createColumn<Nullable<String>>(
{"1.2.3.4", "0.1.0.1", "0.255.0.255", "0.1.2.3", "0.0.0.0", "1.0.1.0", "111.0.21.12"}),
test.executeFunction(
func_name,
createColumn<Nullable<Type>>({16909060, 65537, 16711935, 66051, 0, 16777472, 1862276364})));
}
else if constexpr (std::is_same_v<Int32, Type>)
{
ASSERT_COLUMN_EQ(
createColumn<Nullable<String>>({{}, "127.255.255.255"}),
test.executeFunction(func_name, createColumn<Nullable<Type>>({-1, std::numeric_limits<Type>::max()})));
}
else if constexpr (std::is_same_v<UInt64, Type>)
{
ASSERT_COLUMN_EQ(
createColumn<Nullable<String>>({"255.255.255.255", {}}),
test.executeFunction(
func_name,
createColumn<Nullable<Type>>({std::numeric_limits<UInt32>::max(), std::numeric_limits<Type>::max()})));
}
else if constexpr (std::is_same_v<Int64, Type>)
{
ASSERT_COLUMN_EQ(
createColumn<Nullable<String>>({"255.255.255.255", {}, {}}),
test.executeFunction(
func_name,
createColumn<Nullable<Type>>(
{std::numeric_limits<UInt32>::max(), -1, std::numeric_limits<Type>::max()})));
}
}

// normal cases
ASSERT_COLUMN_EQ(
createColumn<Nullable<String>>(
{"1.2.3.4", "0.1.0.1", "0.255.0.255", "0.1.2.3", "0.0.0.0", "1.0.1.0", "111.0.21.12"}),
executeFunction(
func_name,
createColumn<Nullable<UInt32>>({16909060, 65537, 16711935, 66051, 0, 16777472, 1862276364})));

TEST_F(TestInetAtonNtoa, InetNtoa)
try
{
#define M(T) TestInetAtonNtoaImpl<T>(*this);
M(UInt8);
M(Int8);
M(UInt16);
M(Int16);
M(UInt32);
M(Int32);
M(UInt64);
M(Int64);
#undef M
}
CATCH

Expand All @@ -176,7 +238,7 @@ try
std::uniform_int_distribution<UInt32> dist;

InferredDataVector<Nullable<UInt32>> num_vec;
for (size_t i = 0; i < 10000; ++i)
for (size_t i = 0; i < 512; ++i)
{
num_vec.emplace_back(dist(mt));
}
Expand Down