diff --git a/src/function/function_collection.cpp b/src/function/function_collection.cpp index de1bc06fe51..dfee078fec1 100644 --- a/src/function/function_collection.cpp +++ b/src/function/function_collection.cpp @@ -123,7 +123,7 @@ FunctionCollection* FunctionCollection::getFunctions() { SCALAR_FUNCTION(ListTransformFunction), SCALAR_FUNCTION(ListFilterFunction), SCALAR_FUNCTION(ListReduceFunction), SCALAR_FUNCTION(ListAnyFunction), SCALAR_FUNCTION(ListAllFunction), SCALAR_FUNCTION(ListNoneFunction), - SCALAR_FUNCTION(ListSingleFunction), + SCALAR_FUNCTION(ListSingleFunction), SCALAR_FUNCTION(ListHasAllFunction), // Cast functions SCALAR_FUNCTION(CastToDateFunction), SCALAR_FUNCTION_ALIAS(DateFunction), diff --git a/src/function/list/CMakeLists.txt b/src/function/list/CMakeLists.txt index 69dac41249b..d0f62447898 100644 --- a/src/function/list/CMakeLists.txt +++ b/src/function/list/CMakeLists.txt @@ -24,7 +24,8 @@ add_library(kuzu_list_function list_none.cpp list_single.cpp size_function.cpp - quantifier_functions.cpp) + quantifier_functions.cpp + list_has_all.cpp) set(ALL_OBJECT_FILES ${ALL_OBJECT_FILES} $ diff --git a/src/function/list/list_has_all.cpp b/src/function/list/list_has_all.cpp new file mode 100644 index 00000000000..52e126af4f8 --- /dev/null +++ b/src/function/list/list_has_all.cpp @@ -0,0 +1,67 @@ +#include "common/exception/binder.h" +#include "common/exception/message.h" +#include "common/type_utils.h" +#include "function/list/functions/list_position_function.h" +#include "function/list/vector_list_functions.h" +#include "function/scalar_function.h" + +using namespace kuzu::common; + +namespace kuzu { +namespace function { + +struct ListHasAll { + static void operation(common::list_entry_t& left, common::list_entry_t& right, uint8_t& result, + common::ValueVector& leftVector, common::ValueVector& rightVector, + common::ValueVector& resultVector) { + int64_t pos = 0; + auto rightDataVector = ListVector::getDataVector(&rightVector); + result = true; + for (auto i = 0u; i < right.size; i++) { + common::TypeUtils::visit(ListType::getChildType(rightVector.dataType).getPhysicalType(), + [&](T) { + if (rightDataVector->isNull(right.offset + i)) { + return; + } + ListPosition::operation(left, + *(T*)ListVector::getListValuesWithOffset(&rightVector, right, i), pos, + leftVector, *ListVector::getDataVector(&rightVector), resultVector); + result = (pos != 0); + }); + if (!result) { + return; + } + } + } +}; + +std::unique_ptr bindFunc(ScalarBindFuncInput input) { + std::vector types; + for (auto& arg : input.arguments) { + if (arg->dataType == LogicalType::ANY()) { + types.push_back(LogicalType::LIST(LogicalType::INT64())); + } else { + types.push_back(arg->dataType.copy()); + } + } + if (types[0] != types[1]) { + throw common::BinderException(ExceptionMessage::listFunctionIncompatibleChildrenType( + ListHasAllFunction::name, input.arguments[0]->getDataType().toString(), + input.arguments[1]->getDataType().toString())); + } + return std::make_unique(std::move(types), LogicalType::BOOL()); +} + +function_set ListHasAllFunction::getFunctionSet() { + function_set result; + auto execFunc = ScalarFunction::BinaryExecListStructFunction; + auto function = std::make_unique(name, + std::vector{LogicalTypeID::LIST, LogicalTypeID::LIST}, LogicalTypeID::BOOL, + execFunc, bindFunc); + result.push_back(std::move(function)); + return result; +} + +} // namespace function +} // namespace kuzu diff --git a/src/include/function/list/vector_list_functions.h b/src/include/function/list/vector_list_functions.h index f6f6b2c0867..a3f6832351c 100644 --- a/src/include/function/list/vector_list_functions.h +++ b/src/include/function/list/vector_list_functions.h @@ -203,5 +203,11 @@ struct ListSingleFunction { static function_set getFunctionSet(); }; +struct ListHasAllFunction { + static constexpr const char* name = "LIST_HAS_ALL"; + + static function_set getFunctionSet(); +}; + } // namespace function } // namespace kuzu diff --git a/test/test_files/function/list.test b/test/test_files/function/list.test index d0a3b363ad0..a86e1ed9848 100644 --- a/test/test_files/function/list.test +++ b/test/test_files/function/list.test @@ -2258,3 +2258,56 @@ a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11*a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a12*a0eebc -STATEMENT RETURN [[23, 432], [], [NULL]]; ---- 1 [[23,432],[],[]] + +-CASE ListHasAllTest +-STATEMENT RETURN list_has_all([4, 5, 6], [4, 6]) +---- 1 +True +-STATEMENT RETURN list_has_all([4, 5, 6], [6, 4]) +---- 1 +True +-STATEMENT RETURN list_has_all([4, 5, 6], [4]) +---- 1 +True +-STATEMENT RETURN list_has_all([4, 5, 6], [4,5,6]) +---- 1 +True +-STATEMENT RETURN list_has_all([TRUE, FALSE, FALSE], [FALSE, TRUE]) +---- 1 +True +-STATEMENT RETURN list_has_all([[2,3], [1,2], [5,4]], [[5,4]]) +---- 1 +True +-STATEMENT RETURN list_has_all([[2,3], [1,2], [5,4]], [[1,3]]) +---- 1 +False +-STATEMENT RETURN list_has_all([{a: 5, b:3}, {c: 2, d: 4}], [{c:2, d:4}]) +---- 1 +True +-STATEMENT RETURN list_has_all([{a: 5, b:3}, {c: 2, d: 4}], [{c:2, e:4}]) +---- 1 +True +-STATEMENT RETURN list_has_all([5,6,12], [null]) +---- 1 +True +-STATEMENT RETURN list_has_all([null], [null]) +---- 1 +True +-STATEMENT RETURN list_has_all([], [null]) +---- 1 +True +-STATEMENT RETURN list_has_all([], []) +---- 1 +True +-STATEMENT RETURN list_has_all([null], []) +---- 1 +True +-STATEMENT RETURN list_has_all(null, [1,3,2]) +---- 1 + +-STATEMENT RETURN list_has_all([1,2], null) +---- 1 + +-STATEMENT RETURN list_has_all(null, null) +---- 1 +