Skip to content

Commit

Permalink
[Bugfix][Relay][Keras] Fix the wrong implementation logic about cropp…
Browse files Browse the repository at this point in the history
…ing2D (#15053)

* fix the wrong calculation logic of cropping2d

The implementation of cropping2D is wrong. This pr fix it.

* add a test case to caputure the bug

* Update test_forward.py

* Update test_forward.py

* correct the patch

* Update keras.py

* Update test_forward.py

* Update test_forward.py

* Update test_forward.py
  • Loading branch information
jikechao authored Jun 15, 2023
1 parent 317ec52 commit 90b5acc
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 3 deletions.
10 changes: 8 additions & 2 deletions python/tvm/relay/frontend/keras.py
Original file line number Diff line number Diff line change
Expand Up @@ -816,10 +816,16 @@ def _convert_cropping(
f"Operator {crop_type} is not supported for frontend Keras."
)
int32_max = np.iinfo(np.int32).max
if data_layout == "NHWC":
begin = [0, crop_t, crop_l, 0]
end = [int32_max, in_h - crop_b, in_w - crop_r, int32_max]
else:
begin = [0, 0, crop_t, crop_l]
end = [int32_max, int32_max, in_h - crop_b, in_w - crop_r]
return _op.strided_slice(
inexpr,
begin=[0, 0, crop_t, crop_l],
end=[int32_max, int32_max, in_h - crop_b, in_w - crop_r],
begin=begin,
end=end,
)


Expand Down
10 changes: 9 additions & 1 deletion tests/python/frontend/keras/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -449,7 +449,15 @@ def test_forward_crop(self, keras_mod):
x = keras_mod.layers.Cropping2D(cropping=0)(x)
x = keras_mod.layers.Add()([x, x])
keras_model = keras_mod.models.Model(data, x)
verify_keras_frontend(keras_model)
verify_keras_frontend(keras_model, layout="NHWC")
verify_keras_frontend(keras_model, layout="NHWC")

data = keras_mod.layers.Input(shape=(32, 32, 3))
x = keras_mod.layers.Cropping2D(cropping=(2, 1))(data)
x = keras_mod.layers.Cropping2D(cropping=(1, 2))(x)
keras_model = keras_mod.models.Model(data, x)
verify_keras_frontend(keras_model, layout="NHWC")
verify_keras_frontend(keras_model, layout="NCHW")

def test_forward_multi_inputs(self, keras_mod):
data1 = keras_mod.layers.Input(shape=(32, 32, 3))
Expand Down

0 comments on commit 90b5acc

Please sign in to comment.