Skip to content

Commit

Permalink
Implement list_has_all (#4546)
Browse files Browse the repository at this point in the history
  • Loading branch information
acquamarin authored Nov 19, 2024
1 parent 2591604 commit 6f1bf5b
Show file tree
Hide file tree
Showing 5 changed files with 129 additions and 2 deletions.
2 changes: 1 addition & 1 deletion src/function/function_collection.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ FunctionCollection* FunctionCollection::getFunctions() {
SCALAR_FUNCTION(ListTransformFunction), SCALAR_FUNCTION(ListFilterFunction),
SCALAR_FUNCTION(ListReduceFunction), SCALAR_FUNCTION(ListAnyFunction),
SCALAR_FUNCTION(ListAllFunction), SCALAR_FUNCTION(ListNoneFunction),
SCALAR_FUNCTION(ListSingleFunction),
SCALAR_FUNCTION(ListSingleFunction), SCALAR_FUNCTION(ListHasAllFunction),

// Cast functions
SCALAR_FUNCTION(CastToDateFunction), SCALAR_FUNCTION_ALIAS(DateFunction),
Expand Down
3 changes: 2 additions & 1 deletion src/function/list/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@ add_library(kuzu_list_function
list_none.cpp
list_single.cpp
size_function.cpp
quantifier_functions.cpp)
quantifier_functions.cpp
list_has_all.cpp)

set(ALL_OBJECT_FILES
${ALL_OBJECT_FILES} $<TARGET_OBJECTS:kuzu_list_function>
Expand Down
67 changes: 67 additions & 0 deletions src/function/list/list_has_all.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
#include "common/exception/binder.h"
#include "common/exception/message.h"
#include "common/type_utils.h"
#include "function/list/functions/list_position_function.h"
#include "function/list/vector_list_functions.h"
#include "function/scalar_function.h"

using namespace kuzu::common;

namespace kuzu {
namespace function {

struct ListHasAll {
static void operation(common::list_entry_t& left, common::list_entry_t& right, uint8_t& result,
common::ValueVector& leftVector, common::ValueVector& rightVector,
common::ValueVector& resultVector) {
int64_t pos = 0;
auto rightDataVector = ListVector::getDataVector(&rightVector);
result = true;
for (auto i = 0u; i < right.size; i++) {
common::TypeUtils::visit(ListType::getChildType(rightVector.dataType).getPhysicalType(),
[&]<typename T>(T) {
if (rightDataVector->isNull(right.offset + i)) {
return;
}
ListPosition::operation(left,
*(T*)ListVector::getListValuesWithOffset(&rightVector, right, i), pos,
leftVector, *ListVector::getDataVector(&rightVector), resultVector);
result = (pos != 0);
});
if (!result) {
return;
}
}
}
};

std::unique_ptr<FunctionBindData> bindFunc(ScalarBindFuncInput input) {
std::vector<LogicalType> types;
for (auto& arg : input.arguments) {
if (arg->dataType == LogicalType::ANY()) {
types.push_back(LogicalType::LIST(LogicalType::INT64()));
} else {
types.push_back(arg->dataType.copy());
}
}
if (types[0] != types[1]) {
throw common::BinderException(ExceptionMessage::listFunctionIncompatibleChildrenType(
ListHasAllFunction::name, input.arguments[0]->getDataType().toString(),
input.arguments[1]->getDataType().toString()));
}
return std::make_unique<FunctionBindData>(std::move(types), LogicalType::BOOL());
}

function_set ListHasAllFunction::getFunctionSet() {
function_set result;
auto execFunc = ScalarFunction::BinaryExecListStructFunction<list_entry_t, list_entry_t,
uint8_t, ListHasAll>;
auto function = std::make_unique<ScalarFunction>(name,
std::vector<LogicalTypeID>{LogicalTypeID::LIST, LogicalTypeID::LIST}, LogicalTypeID::BOOL,
execFunc, bindFunc);
result.push_back(std::move(function));
return result;
}

} // namespace function
} // namespace kuzu
6 changes: 6 additions & 0 deletions src/include/function/list/vector_list_functions.h
Original file line number Diff line number Diff line change
Expand Up @@ -203,5 +203,11 @@ struct ListSingleFunction {
static function_set getFunctionSet();
};

struct ListHasAllFunction {
static constexpr const char* name = "LIST_HAS_ALL";

static function_set getFunctionSet();
};

} // namespace function
} // namespace kuzu
53 changes: 53 additions & 0 deletions test/test_files/function/list.test
Original file line number Diff line number Diff line change
Expand Up @@ -2258,3 +2258,56 @@ a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11*a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a12*a0eebc
-STATEMENT RETURN [[23, 432], [], [NULL]];
---- 1
[[23,432],[],[]]

-CASE ListHasAllTest
-STATEMENT RETURN list_has_all([4, 5, 6], [4, 6])
---- 1
True
-STATEMENT RETURN list_has_all([4, 5, 6], [6, 4])
---- 1
True
-STATEMENT RETURN list_has_all([4, 5, 6], [4])
---- 1
True
-STATEMENT RETURN list_has_all([4, 5, 6], [4,5,6])
---- 1
True
-STATEMENT RETURN list_has_all([TRUE, FALSE, FALSE], [FALSE, TRUE])
---- 1
True
-STATEMENT RETURN list_has_all([[2,3], [1,2], [5,4]], [[5,4]])
---- 1
True
-STATEMENT RETURN list_has_all([[2,3], [1,2], [5,4]], [[1,3]])
---- 1
False
-STATEMENT RETURN list_has_all([{a: 5, b:3}, {c: 2, d: 4}], [{c:2, d:4}])
---- 1
True
-STATEMENT RETURN list_has_all([{a: 5, b:3}, {c: 2, d: 4}], [{c:2, e:4}])
---- 1
True
-STATEMENT RETURN list_has_all([5,6,12], [null])
---- 1
True
-STATEMENT RETURN list_has_all([null], [null])
---- 1
True
-STATEMENT RETURN list_has_all([], [null])
---- 1
True
-STATEMENT RETURN list_has_all([], [])
---- 1
True
-STATEMENT RETURN list_has_all([null], [])
---- 1
True
-STATEMENT RETURN list_has_all(null, [1,3,2])
---- 1

-STATEMENT RETURN list_has_all([1,2], null)
---- 1

-STATEMENT RETURN list_has_all(null, null)
---- 1

0 comments on commit 6f1bf5b

Please sign in to comment.