Skip to content

Commit

Permalink
move _process_generic to where it is called, avoid circular import
Browse files Browse the repository at this point in the history
  • Loading branch information
dcolinmorgan committed Dec 11, 2023
1 parent c524553 commit dec561e
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 39 deletions.
39 changes: 0 additions & 39 deletions python/cuml/internals/base_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,45 +31,6 @@
from cuml.internals.constants import CUML_WRAPPED_FLAG


def _process_generic(gen_type):

# Check if the type is not a generic. If not, must return "generic" if
# subtype is CumlArray otherwise None
if not isinstance(gen_type, typing._GenericAlias):
if issubclass(gen_type, CumlArray):
return "generic"

# We don't handle SparseCumlArray at this time
if issubclass(gen_type, SparseCumlArray):
raise NotImplementedError(
"Generic return types with SparseCumlArray are not supported "
"at this time"
)

# Otherwise None (keep processing)
return None

# Its a generic type by this point. Support Union, Tuple, Dict and List
supported_gen_types = [
tuple,
dict,
list,
typing.Union,
]

if gen_type.__origin__ in supported_gen_types:
# Check for a CumlArray type in the args
for arg in gen_type.__args__:
inner_type = _process_generic(arg)

if inner_type is not None:
return inner_type
else:
raise NotImplementedError("Unknow generic type: {}".format(gen_type))

return None


def _wrap_attribute(class_name: str, attribute_name: str, attribute, **kwargs):

# Skip items marked with autowrap_ignore
Expand Down
39 changes: 39 additions & 0 deletions python/cuml/internals/base_return_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,45 @@
from cuml.internals.array_sparse import SparseCumlArray


def _process_generic(gen_type):

# Check if the type is not a generic. If not, must return "generic" if
# subtype is CumlArray otherwise None
if not isinstance(gen_type, typing._GenericAlias):
if issubclass(gen_type, CumlArray):
return "generic"

# We don't handle SparseCumlArray at this time
if issubclass(gen_type, SparseCumlArray):
raise NotImplementedError(
"Generic return types with SparseCumlArray are not supported "
"at this time"
)

# Otherwise None (keep processing)
return None

# Its a generic type by this point. Support Union, Tuple, Dict and List
supported_gen_types = [
tuple,
dict,
list,
typing.Union,
]

if gen_type.__origin__ in supported_gen_types:
# Check for a CumlArray type in the args
for arg in gen_type.__args__:
inner_type = _process_generic(arg)

if inner_type is not None:
return inner_type
else:
raise NotImplementedError("Unknow generic type: {}".format(gen_type))

return None


def _get_base_return_type(class_name, attr):

if (
Expand Down

0 comments on commit dec561e

Please sign in to comment.