Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SyclPlatform equality testing and hashing implemented #1333

Merged
merged 5 commits into from
Aug 11, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions dpctl/_backend.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -299,6 +299,7 @@ cdef extern from "syclinterface/dpctl_sycl_platform_manager.h":


cdef extern from "syclinterface/dpctl_sycl_platform_interface.h":
cdef bool DPCTLPlatform_AreEq(const DPCTLSyclPlatformRef, const DPCTLSyclPlatformRef)
cdef DPCTLSyclPlatformRef DPCTLPlatform_Copy(const DPCTLSyclPlatformRef)
cdef DPCTLSyclPlatformRef DPCTLPlatform_Create()
cdef DPCTLSyclPlatformRef DPCTLPlatform_CreateFromSelector(
Expand All @@ -308,6 +309,7 @@ cdef extern from "syclinterface/dpctl_sycl_platform_interface.h":
cdef const char *DPCTLPlatform_GetName(const DPCTLSyclPlatformRef)
cdef const char *DPCTLPlatform_GetVendor(const DPCTLSyclPlatformRef)
cdef const char *DPCTLPlatform_GetVersion(const DPCTLSyclPlatformRef)
cdef size_t DPCTLPlatform_Hash(const DPCTLSyclPlatformRef)
cdef DPCTLPlatformVectorRef DPCTLPlatform_GetPlatforms()
cdef DPCTLSyclContextRef DPCTLPlatform_GetDefaultContext(
const DPCTLSyclPlatformRef)
Expand Down
3 changes: 3 additions & 0 deletions dpctl/_sycl_platform.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
SYCL platform-related helper functions.
"""

from libcpp cimport bool

from ._backend cimport DPCTLSyclDeviceSelectorRef, DPCTLSyclPlatformRef


Expand All @@ -40,6 +42,7 @@ cdef class SyclPlatform(_SyclPlatform):
cdef int _init_from_selector(self, DPCTLSyclDeviceSelectorRef DSRef)
cdef int _init_from__SyclPlatform(self, _SyclPlatform other)
cdef DPCTLSyclPlatformRef get_platform_ref(self)
cdef bool equals(self, SyclPlatform)


cpdef list get_platforms()
40 changes: 40 additions & 0 deletions dpctl/_sycl_platform.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,13 @@
""" Implements SyclPlatform Cython extension type.
"""

from libcpp cimport bool

from ._backend cimport ( # noqa: E211
DPCTLCString_Delete,
DPCTLDeviceSelector_Delete,
DPCTLFilterSelector_Create,
DPCTLPlatform_AreEq,
DPCTLPlatform_Copy,
DPCTLPlatform_Create,
DPCTLPlatform_CreateFromSelector,
Expand All @@ -35,6 +38,7 @@ from ._backend cimport ( # noqa: E211
DPCTLPlatform_GetPlatforms,
DPCTLPlatform_GetVendor,
DPCTLPlatform_GetVersion,
DPCTLPlatform_Hash,
DPCTLPlatformMgr_GetInfo,
DPCTLPlatformMgr_PrintInfo,
DPCTLPlatformVector_Delete,
Expand Down Expand Up @@ -274,6 +278,42 @@ cdef class SyclPlatform(_SyclPlatform):
else:
return SyclContext._create(CRef)

cdef bool equals(self, SyclPlatform other):
"""
Returns true if the :class:`dpctl.SyclPlatform` argument has the
same underlying ``DPCTLSyclPlatformRef`` object as this
:class:`dpctl.SyclPlatform` instance.

Returns:
:obj:`bool`: ``True`` if the two :class:`dpctl.SyclPlatform` objects
point to the same ``DPCTLSyclPlatformRef`` object, otherwise
``False``.
"""
return DPCTLPlatform_AreEq(self._platform_ref, other.get_platform_ref())

def __eq__(self, other):
"""
Returns True if the :class:`dpctl.SyclPlatform` argument has the
same underlying ``DPCTLSyclPlatformRef`` object as this
:class:`dpctl.SyclPlatform` instance.

Returns:
:obj:`bool`: ``True`` if the two :class:`dpctl.SyclPlatform` objects
point to the same ``DPCTLSyclPlatformRef`` object, otherwise
``False``.
"""
if isinstance(other, SyclPlatform):
return self.equals(<SyclPlatform> other)
else:
return False

def __hash__(self):
"""
Returns a hash value by hashing the underlying ``sycl::platform`` object.

"""
return DPCTLPlatform_Hash(self._platform_ref)


def lsplatform(verbosity=0):
"""
Expand Down
22 changes: 22 additions & 0 deletions dpctl/tests/test_sycl_platform.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
"""Defines unit test cases for the SyclPlatform class.
"""

import sys

import pytest
from helper import has_sycl_platforms

Expand Down Expand Up @@ -88,17 +90,37 @@ def check_repr(platform):


def check_default_context(platform):
if "linux" not in sys.platform:
return
r = platform.default_context
assert type(r) is dpctl.SyclContext


def check_equal_and_hash(platform):
assert platform == platform
if "linux" not in sys.platform:
return
default_ctx = platform.default_context
for d in default_ctx.get_devices():
assert platform == d.sycl_platform
assert hash(platform) == hash(d.sycl_platform)


def check_hash_in_dict(platform):
map = {platform: 0}
assert map[platform] == 0


list_of_checks = [
check_name,
check_vendor,
check_version,
check_backend,
check_print_info,
check_repr,
check_default_context,
check_equal_and_hash,
check_hash_in_dict,
]


Expand Down
2 changes: 1 addition & 1 deletion libsyclinterface/include/dpctl_sycl_context_interface.h
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,6 @@ void DPCTLContext_Delete(__dpctl_take DPCTLSyclContextRef CtxRef);
* @ingroup ContextInterface
*/
DPCTL_API
size_t DPCTLContext_Hash(__dpctl_take DPCTLSyclContextRef CtxRef);
size_t DPCTLContext_Hash(__dpctl_keep DPCTLSyclContextRef CtxRef);

DPCTL_C_EXTERN_C_END
23 changes: 23 additions & 0 deletions libsyclinterface/include/dpctl_sycl_platform_interface.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,19 @@ DPCTL_C_EXTERN_C_BEGIN
* @defgroup PlatformInterface Platform class C wrapper
*/

/*!
* @brief Checks if two DPCTLSyclPlatformRef objects point to the same
* sycl::platform.
*
* @param PRef1 First opaque pointer to a ``sycl::platform``.
* @param PRef2 Second opaque pointer to a ``sycl::platform``.
* @return True if the underlying sycl::platform are same, false otherwise.
* @ingroup PlatformInterface
*/
DPCTL_API
bool DPCTLPlatform_AreEq(__dpctl_keep const DPCTLSyclPlatformRef PRef1,
__dpctl_keep const DPCTLSyclPlatformRef PRef2);

/*!
* @brief Returns a copy of the DPCTLSyclPlatformRef object.
*
Expand Down Expand Up @@ -155,4 +168,14 @@ DPCTL_API
__dpctl_give DPCTLSyclContextRef
DPCTLPlatform_GetDefaultContext(__dpctl_keep const DPCTLSyclPlatformRef PRef);

/*!
* @brief Wrapper over std::hash<sycl::platform>'s operator()
*
* @param PRef The DPCTLSyclPlatformRef pointer.
* @return Hash value of the underlying ``sycl::platform`` instance.
* @ingroup PlatformInterface
*/
DPCTL_API
size_t DPCTLPlatform_Hash(__dpctl_keep DPCTLSyclPlatformRef CtxRef);

DPCTL_C_EXTERN_C_END
24 changes: 24 additions & 0 deletions libsyclinterface/source/dpctl_sycl_platform_interface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -234,3 +234,27 @@ DPCTLPlatform_GetDefaultContext(__dpctl_keep const DPCTLSyclPlatformRef PRef)
return nullptr;
}
}

bool DPCTLPlatform_AreEq(__dpctl_keep const DPCTLSyclPlatformRef PRef1,
__dpctl_keep const DPCTLSyclPlatformRef PRef2)
{
auto P1 = unwrap<platform>(PRef1);
auto P2 = unwrap<platform>(PRef2);
if (P1 && P2)
return *P1 == *P2;
else
return false;
}

size_t DPCTLPlatform_Hash(__dpctl_keep const DPCTLSyclPlatformRef PRef)
{
if (PRef) {
auto P = unwrap<platform>(PRef);
std::hash<platform> hash_fn;
return hash_fn(*P);
}
else {
error_handler("Argument PRef is null.", __FILE__, __func__, __LINE__);
return 0;
}
}
19 changes: 19 additions & 0 deletions libsyclinterface/tests/test_sycl_platform_interface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,25 @@ TEST_P(TestDPCTLSyclPlatformInterface, ChkPrintInfoNullArg)
EXPECT_NO_FATAL_FAILURE(DPCTLPlatformMgr_PrintInfo(Null_PRef, 0));
}

TEST_P(TestDPCTLSyclPlatformInterface, ChkAreEq)
{
DPCTLSyclPlatformRef PRef_Copy = nullptr;

EXPECT_NO_FATAL_FAILURE(PRef_Copy = DPCTLPlatform_Copy(PRef));

ASSERT_TRUE(DPCTLPlatform_AreEq(PRef, PRef_Copy));
EXPECT_TRUE(DPCTLPlatform_Hash(PRef) == DPCTLPlatform_Hash(PRef_Copy));

EXPECT_NO_FATAL_FAILURE(DPCTLPlatform_Delete(PRef_Copy));
}

TEST_P(TestDPCTLSyclPlatformInterface, ChkAreEqNullArg)
{
DPCTLSyclPlatformRef Null_PRef = nullptr;
ASSERT_FALSE(DPCTLPlatform_AreEq(PRef, Null_PRef));
ASSERT_TRUE(DPCTLPlatform_Hash(Null_PRef) == 0);
}

TEST_F(TestDPCTLSyclDefaultPlatform, ChkGetName)
{
check_platform_name(PRef);
Expand Down