Skip to content

Commit

Permalink
[Dy2static] add set_dynamic_shape for dy2static high-level users. (Pa…
Browse files Browse the repository at this point in the history
…ddlePaddle#58155)

* fix cases

* fix
  • Loading branch information
2742195759 authored and wentaoyu committed Oct 24, 2023
1 parent 165bf01 commit 9656c75
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 0 deletions.
11 changes: 11 additions & 0 deletions python/paddle/jit/dy2static/utils_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,3 +183,14 @@ def type_from_annotation(annotation):
# raise warning if not found
warn("Currently we don't support annotation: %s" % annotation_str)
return NodeVarType.UNKNOWN


def set_dynamic_shape(variable, shape_list):
if paddle.base.dygraph.base.in_to_static_mode():
assert isinstance(
variable, paddle.base.framework.Variable
), "In to_static mode, variable must be a Variable."
variable.desc.set_shape(shape_list)
else:
# in dygraph mode, dynamic shape is not needed, just do nothing.
return
39 changes: 39 additions & 0 deletions test/dygraph_to_static/test_set_dynamic_shape.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import paddle


class TestSetDynamicShape(unittest.TestCase):
def test_start(self):
def dygraph_func(loop_number):
mask = paddle.randn([2, 2])
paddle.jit.dy2static.utils_helper.set_dynamic_shape(mask, [-1, 2])
n = paddle.randn([1, 2])
for i in range(loop_number):
mask = paddle.concat([mask, n], axis=0)
if mask.shape[0] == 5:
break
return mask

loop_num = paddle.to_tensor(10)
expected_shape = dygraph_func(loop_num).shape
actual_shape = paddle.jit.to_static(dygraph_func)(loop_num).shape
self.assertEqual(expected_shape, actual_shape)


if __name__ == '__main__':
unittest.main()

0 comments on commit 9656c75

Please sign in to comment.