Skip to content

Commit

Permalink
[Torch] chunk and unsafe chunk (#8718)
Browse files Browse the repository at this point in the history
* alternative chunk op was implemented in pytorch frontend. aten::unsafe_chunk was added to op map in pytorch frontend

* chunk was replaced by new one in pytorch frontend. it is faster in 2.5 times

Co-authored-by: Valery Chernov <[email protected]>
  • Loading branch information
vvchernov and Valery Chernov authored Aug 13, 2021
1 parent 8843153 commit 7cf7adf
Showing 1 changed file with 5 additions and 21 deletions.
26 changes: 5 additions & 21 deletions python/tvm/relay/frontend/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -1594,28 +1594,11 @@ def chunk(self, inputs, input_types):
else:
unif_size = int(dim / num_chunks)

chunks = []
for i in range(0, dim, unif_size):
begin = [0] * len(shape)
end = shape[:]
begin[axis] = i
end[axis] = i + unif_size
stride = [1] * len(shape)
indeces = []
for i in range(unif_size, dim, unif_size):
indeces.append(i)

chunk_out = _op.transform.strided_slice(data, begin=begin, end=end, strides=stride)
chunks.append(chunk_out)

if dim % num_chunks:
begin = [0] * len(shape)
end = shape[:]
begin[axis] = unif_size * (num_chunks - 1)
end[axis] = dim
stride = [1] * len(shape)

chunk_out = _op.transform.strided_slice(data, begin=begin, end=end, strides=stride)
chunks.append(chunk_out)

return chunks
return _op.split(data, indeces, axis)

def matmul(self, inputs, input_types):

Expand Down Expand Up @@ -2681,6 +2664,7 @@ def create_convert_map(self):
"aten::alpha_dropout": self.dropout,
"aten::mean": self.mean,
"aten::chunk": self.chunk,
"aten::unsafe_chunk": self.chunk,
"aten::matmul": self.matmul,
"aten::bmm": self.matmul,
"aten::expand": self.expand,
Expand Down

0 comments on commit 7cf7adf

Please sign in to comment.