diff --git a/ivy/functional/frontends/torch/indexing_slicing_joining_mutating_ops.py b/ivy/functional/frontends/torch/indexing_slicing_joining_mutating_ops.py index a5731b5adbf7c..de668bcc5b7a6 100644 --- a/ivy/functional/frontends/torch/indexing_slicing_joining_mutating_ops.py +++ b/ivy/functional/frontends/torch/indexing_slicing_joining_mutating_ops.py @@ -215,18 +215,40 @@ def split(tensor, split_size_or_sections, dim=0): ) +def _get_indices_or_sections(indices_or_sections, indices, sections): + if not ivy.exists(indices_or_sections): + if ivy.exists(indices) and not ivy.exists(sections): + indices_or_sections = indices + elif ivy.exists(sections) and not ivy.exists(indices): + indices_or_sections = sections + else: + raise ivy.utils.exception.IvyError( + "got invalid argument for indices_or_sections" + ) + return indices_or_sections + + @to_ivy_arrays_and_back -def dsplit(input, indices_or_sections): +def dsplit(input, indices_or_sections=None, /, *, indices=None, sections=None): + indices_or_sections = _get_indices_or_sections( + indices_or_sections, indices, sections + ) return tuple(ivy.dsplit(input, indices_or_sections)) @to_ivy_arrays_and_back -def hsplit(input, indices_or_sections): +def hsplit(input, indices_or_sections=None, /, *, indices=None, sections=None): + indices_or_sections = _get_indices_or_sections( + indices_or_sections, indices, sections + ) return tuple(ivy.hsplit(input, indices_or_sections)) @to_ivy_arrays_and_back -def vsplit(input, indices_or_sections): +def vsplit(input, indices_or_sections=None, /, *, indices=None, sections=None): + indices_or_sections = _get_indices_or_sections( + indices_or_sections, indices, sections + ) return tuple(ivy.vsplit(input, indices_or_sections))