Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support other integer types for SubstringUTF8 & RightUTF8 functions (#9507) #9516

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
187 changes: 120 additions & 67 deletions dbms/src/Functions/FunctionsString.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1681,26 +1681,41 @@ class FunctionSubstringUTF8 : public IFunction
bool is_start_type_valid
= getNumberType(block.getByPosition(arguments[1]).type, [&](const auto & start_type, bool) {
using StartType = std::decay_t<decltype(start_type)>;
// Int64 / UInt64
using StartFieldType = typename StartType::FieldType;
const ColumnVector<StartFieldType> * column_vector_start
= getInnerColumnVector<StartFieldType>(column_start);
if unlikely (!column_vector_start)
throw Exception(
fmt::format(
"Illegal type {} of argument 2 of function {}",
block.getByPosition(arguments[1]).type->getName(),
getName()),
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);

// vector const const
if (!column_string->isColumnConst() && column_start->isColumnConst()
&& (implicit_length || block.getByPosition(arguments[2]).column->isColumnConst()))
{
auto [is_positive, start_abs]
= getValueFromStartField<StartFieldType>((*block.getByPosition(arguments[1]).column)[0]);
auto [is_positive, start_abs] = getValueFromStartColumn<StartFieldType>(*column_vector_start, 0);
UInt64 length = 0;
if (!implicit_length)
{
bool is_length_type_valid = getNumberType(
block.getByPosition(arguments[2]).type,
[&](const auto & length_type, bool) {
using LengthType = std::decay_t<decltype(length_type)>;
// Int64 / UInt64
using LengthFieldType = typename LengthType::FieldType;
length = getValueFromLengthField<LengthFieldType>(
(*block.getByPosition(arguments[2]).column)[0]);
const ColumnVector<LengthFieldType> * column_vector_length
= getInnerColumnVector<LengthFieldType>(block.getByPosition(arguments[2]).column);
if unlikely (!column_vector_length)
throw Exception(
fmt::format(
"Illegal type {} of argument 3 of function {}",
block.getByPosition(arguments[2]).type->getName(),
getName()),
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);

length = getValueFromLengthColumn<LengthFieldType>(*column_vector_length, 0);
return true;
});

Expand Down Expand Up @@ -1735,15 +1750,15 @@ class FunctionSubstringUTF8 : public IFunction
if (column_start->isColumnConst())
{
// func always return const value
auto start_const = getValueFromStartField<StartFieldType>((*column_start)[0]);
auto start_const = getValueFromStartColumn<StartFieldType>(*column_vector_start, 0);
get_start_func = [start_const](size_t) {
return start_const;
};
}
else
{
get_start_func = [&column_start](size_t i) {
return getValueFromStartField<StartFieldType>((*column_start)[i]);
get_start_func = [column_vector_start](size_t i) {
return getValueFromStartColumn<StartFieldType>(*column_vector_start, i);
};
}

Expand All @@ -1756,26 +1771,36 @@ class FunctionSubstringUTF8 : public IFunction
block.getByPosition(arguments[2]).type,
[&](const auto & length_type, bool) {
using LengthType = std::decay_t<decltype(length_type)>;
// Int64 / UInt64
using LengthFieldType = typename LengthType::FieldType;
const ColumnVector<LengthFieldType> * column_vector_length
= getInnerColumnVector<LengthFieldType>(column_length);
if unlikely (!column_vector_length)
throw Exception(
fmt::format(
"Illegal type {} of argument 3 of function {}",
block.getByPosition(arguments[2]).type->getName(),
getName()),
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);

if (column_length->isColumnConst())
{
// func always return const value
auto length_const = getValueFromLengthField<LengthFieldType>((*column_length)[0]);
auto length_const
= getValueFromLengthColumn<LengthFieldType>(*column_vector_length, 0);
get_length_func = [length_const](size_t) {
return length_const;
};
}
else
{
get_length_func = [column_length](size_t i) {
return getValueFromLengthField<LengthFieldType>((*column_length)[i]);
get_length_func = [column_vector_length](size_t i) {
return getValueFromLengthColumn<LengthFieldType>(*column_vector_length, i);
};
}
return true;
});

if (!is_length_type_valid)
if unlikely (!is_length_type_valid)
throw Exception(
fmt::format("3nd argument of function {} must have UInt/Int type.", getName()));
}
Expand Down Expand Up @@ -1813,10 +1838,38 @@ class FunctionSubstringUTF8 : public IFunction
return true;
});

if (!is_start_type_valid)
if unlikely (!is_start_type_valid)
throw Exception(fmt::format("2nd argument of function {} must have UInt/Int type.", getName()));
}

template <typename Integer>
static const ColumnVector<Integer> * getInnerColumnVector(const ColumnPtr & column)
{
if (column->isColumnConst())
return checkAndGetColumn<ColumnVector<Integer>>(
checkAndGetColumn<ColumnConst>(column.get())->getDataColumnPtr().get());
return checkAndGetColumn<ColumnVector<Integer>>(column.get());
}

template <typename Integer>
static size_t getValueFromLengthColumn(const ColumnVector<Integer> & column, size_t index)
{
Integer val = column.getElement(index);
if constexpr (
std::is_same_v<Integer, Int8> || std::is_same_v<Integer, Int16> || std::is_same_v<Integer, Int32>
|| std::is_same_v<Integer, Int64>)
{
return val < 0 ? 0 : val;
}
else
{
static_assert(
std::is_same_v<Integer, UInt8> || std::is_same_v<Integer, UInt16> || std::is_same_v<Integer, UInt32>
|| std::is_same_v<Integer, UInt64>);
return val;
}
}

private:
using VectorConstConstFunc = std::function<void(
const ColumnString::Chars_t &,
Expand All @@ -1840,49 +1893,40 @@ class FunctionSubstringUTF8 : public IFunction
}
}

template <typename Integer>
static size_t getValueFromLengthField(const Field & length_field)
{
if constexpr (std::is_same_v<Integer, Int64>)
{
Int64 signed_length = length_field.get<Int64>();
return signed_length < 0 ? 0 : signed_length;
}
else
{
static_assert(std::is_same_v<Integer, UInt64>);
return length_field.get<UInt64>();
}
}

// return {is_positive, abs}
template <typename Integer>
static std::pair<bool, size_t> getValueFromStartField(const Field & start_field)
static std::pair<bool, size_t> getValueFromStartColumn(const ColumnVector<Integer> & column, size_t index)
{
if constexpr (std::is_same_v<Integer, Int64>)
Integer val = column.getElement(index);
if constexpr (
std::is_same_v<Integer, Int8> || std::is_same_v<Integer, Int16> || std::is_same_v<Integer, Int32>
|| std::is_same_v<Integer, Int64>)
{
Int64 signed_length = start_field.get<Int64>();

if (signed_length < 0)
{
return {false, static_cast<size_t>(-signed_length)};
}
else
{
return {true, static_cast<size_t>(signed_length)};
}
if (val < 0)
return {false, static_cast<size_t>(-val)};
return {true, static_cast<size_t>(val)};
}
else
{
static_assert(std::is_same_v<Integer, UInt64>);
return {true, start_field.get<UInt64>()};
static_assert(
std::is_same_v<Integer, UInt8> || std::is_same_v<Integer, UInt16> || std::is_same_v<Integer, UInt32>
|| std::is_same_v<Integer, UInt64>);
return {true, val};
}
}

template <typename F>
static bool getNumberType(DataTypePtr type, F && f)
{
return castTypeToEither<DataTypeInt64, DataTypeUInt64>(type.get(), std::forward<F>(f));
return castTypeToEither<
DataTypeUInt8,
DataTypeUInt16,
DataTypeUInt32,
DataTypeUInt64,
DataTypeInt8,
DataTypeInt16,
DataTypeInt32,
DataTypeInt64>(type.get(), std::forward<F>(f));
}
};

Expand Down Expand Up @@ -1921,16 +1965,28 @@ class FunctionRightUTF8 : public IFunction
bool is_length_type_valid
= getLengthType(block.getByPosition(arguments[1]).type, [&](const auto & length_type, bool) {
using LengthType = std::decay_t<decltype(length_type)>;
// Int64 / UInt64
using LengthFieldType = typename LengthType::FieldType;

const ColumnVector<LengthFieldType> * column_vector_length
= FunctionSubstringUTF8::getInnerColumnVector<LengthFieldType>(column_length);
if unlikely (!column_vector_length)
throw Exception(
fmt::format(
"Illegal type {} of argument 2 of function {}",
block.getByPosition(arguments[1]).type->getName(),
getName()),
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);


auto col_res = ColumnString::create();
if (const auto * col_string = checkAndGetColumn<ColumnString>(column_string.get()))
{
if (column_length->isColumnConst())
{
// vector const
size_t length = getValueFromLengthField<LengthFieldType>((*column_length)[0]);
size_t length = FunctionSubstringUTF8::getValueFromLengthColumn<LengthFieldType>(
*column_vector_length,
0);

// for const 0, return const blank string.
if (0 == length)
Expand All @@ -1950,8 +2006,10 @@ class FunctionRightUTF8 : public IFunction
else
{
// vector vector
auto get_length_func = [&column_length](size_t i) {
return getValueFromLengthField<LengthFieldType>((*column_length)[i]);
auto get_length_func = [column_vector_length](size_t i) {
return FunctionSubstringUTF8::getValueFromLengthColumn<LengthFieldType>(
*column_vector_length,
i);
};
RightUTF8Impl::vectorVector(
col_string->getChars(),
Expand All @@ -1970,8 +2028,10 @@ class FunctionRightUTF8 : public IFunction
assert(col_string_from_const);
// When useDefaultImplementationForConstants is true, string and length are not both constants
assert(!column_length->isColumnConst());
auto get_length_func = [&column_length](size_t i) {
return getValueFromLengthField<LengthFieldType>((*column_length)[i]);
auto get_length_func = [column_vector_length](size_t i) {
return FunctionSubstringUTF8::getValueFromLengthColumn<LengthFieldType>(
*column_vector_length,
i);
};
RightUTF8Impl::constVector(
column_length->size(),
Expand All @@ -1998,22 +2058,15 @@ class FunctionRightUTF8 : public IFunction
template <typename F>
static bool getLengthType(DataTypePtr type, F && f)
{
return castTypeToEither<DataTypeInt64, DataTypeUInt64>(type.get(), std::forward<F>(f));
}

template <typename Integer>
static size_t getValueFromLengthField(const Field & length_field)
{
if constexpr (std::is_same_v<Integer, Int64>)
{
Int64 signed_length = length_field.get<Int64>();
return signed_length < 0 ? 0 : signed_length;
}
else
{
static_assert(std::is_same_v<Integer, UInt64>);
return length_field.get<UInt64>();
}
return castTypeToEither<
DataTypeUInt8,
DataTypeUInt16,
DataTypeUInt32,
DataTypeUInt64,
DataTypeInt8,
DataTypeInt16,
DataTypeInt32,
DataTypeInt64>(type.get(), std::forward<F>(f));
}
};

Expand Down
Loading