Skip to content

Commit

Permalink
Merge pull request #10810 from xoiga123/fix_frontend_torch_dhvsplit
Browse files Browse the repository at this point in the history
fix frontends.torch.{h,v,d}split
  • Loading branch information
Infrared1029 authored Feb 23, 2023
2 parents 47a28b4 + 8efe3af commit 069ebcf
Showing 1 changed file with 25 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -215,18 +215,40 @@ def split(tensor, split_size_or_sections, dim=0):
)


def _get_indices_or_sections(indices_or_sections, indices, sections):
if not ivy.exists(indices_or_sections):
if ivy.exists(indices) and not ivy.exists(sections):
indices_or_sections = indices
elif ivy.exists(sections) and not ivy.exists(indices):
indices_or_sections = sections
else:
raise ivy.utils.exception.IvyError(
"got invalid argument for indices_or_sections"
)
return indices_or_sections


@to_ivy_arrays_and_back
def dsplit(input, indices_or_sections):
def dsplit(input, indices_or_sections=None, /, *, indices=None, sections=None):
indices_or_sections = _get_indices_or_sections(
indices_or_sections, indices, sections
)
return tuple(ivy.dsplit(input, indices_or_sections))


@to_ivy_arrays_and_back
def hsplit(input, indices_or_sections):
def hsplit(input, indices_or_sections=None, /, *, indices=None, sections=None):
indices_or_sections = _get_indices_or_sections(
indices_or_sections, indices, sections
)
return tuple(ivy.hsplit(input, indices_or_sections))


@to_ivy_arrays_and_back
def vsplit(input, indices_or_sections):
def vsplit(input, indices_or_sections=None, /, *, indices=None, sections=None):
indices_or_sections = _get_indices_or_sections(
indices_or_sections, indices, sections
)
return tuple(ivy.vsplit(input, indices_or_sections))


Expand Down

0 comments on commit 069ebcf

Please sign in to comment.