diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py b/python/tvm/relax/frontend/onnx/onnx_frontend.py index ba121b7ec4fa..f09cc56de372 100644 --- a/python/tvm/relax/frontend/onnx/onnx_frontend.py +++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py @@ -1918,6 +1918,22 @@ def _impl_v1(cls, bb, inputs, attr, params): ) + relax.op.nn.relu(inputs[0]) +class HardSwish(OnnxOpConverter): + """Converts an onnx HardSwish node into an equivalent Relax expression.""" + + @classmethod + def _impl_v14(cls, bb, inputs, attr, params): + x = inputs[0] + dtype = x.struct_info.dtype + return relax.op.multiply( + x, + relax.op.divide( + relax.op.clip(relax.op.add(x, relax.const(3, dtype)), 0, 6), + relax.expr.const(6, dtype), + ), + ) + + def _get_convert_map(): return { "MatMul": MatMul, @@ -1998,6 +2014,7 @@ def _get_convert_map(): "Reciprocal": Reciprocal, "OneHot": OneHot, "Elu": Elu, + "HardSwish": HardSwish, } diff --git a/tests/python/relax/test_frontend_onnx.py b/tests/python/relax/test_frontend_onnx.py index 8dbd7851b0dd..0161534d17f7 100644 --- a/tests/python/relax/test_frontend_onnx.py +++ b/tests/python/relax/test_frontend_onnx.py @@ -590,6 +590,10 @@ def test_elu(): verify_unary("Elu", [32, 32]) +def test_hardswish(): + verify_unary("HardSwish", [32, 32]) + + def test_conv(): def _verify_conv(input_shape, weight_shape, output_shape): bias_shape = [output_shape[1]]