Skip to content

Commit

Permalink
[Yaml]Support parsing fwd & bwd returns with name (#40107)
Browse files Browse the repository at this point in the history
  • Loading branch information
jim19930609 authored Mar 4, 2022
1 parent 73a4fe6 commit d2a911b
Showing 1 changed file with 14 additions and 27 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -208,39 +208,26 @@ def ParseYamlArgs(string):


def ParseYamlReturns(string):
# Example: Tensor, Tensor

# list = [ ["", ret_type, orig_position], ...]
returns_list = []

returns = [x.strip() for x in string.strip().split(",")]
for i in range(len(returns)):
ret_type = returns[i]

assert ret_type in yaml_types_mapping.keys()
ret_type = yaml_types_mapping[ret_type]

returns_list.append(["", ret_type, i])

return returns_list


def ParseYamlReturnsWithName(string):
# Example: Tensor(out), Tensor(out1)
# Example0: Tensor(out), Tensor(out1)
# Example1: Tensor, Tensor
# Example2: Tensor[](out), Tensor

# list = [ [ret_name, ret_type, orig_position], ...]
returns_list = []

returns = [x.strip() for x in string.strip().split(",")]

atype = r'(.*?)'
aname = r'(.*?)'
pattern = f'{atype}\({aname}\)'
for i in range(len(returns)):
ret = returns[i]
m = re.search(pattern, ret)
ret_type = m.group(1)
ret_name = m.group(2)

ret_name = ""
if "(" in ret and ")" in ret:
# Remove trailing ')'
ret = ret[:-1]
ret_type = ret.split("(")[0].strip()
ret_name = ret.split("(")[1].strip()
else:
ret_type = ret.strip()

assert ret_type in yaml_types_mapping.keys()
ret_type = yaml_types_mapping[ret_type]
Expand All @@ -266,7 +253,7 @@ def ParseYamlForwardFromBackward(string):
function_returns = m.group(3)

forward_inputs_list, forward_attrs_list = ParseYamlArgs(function_args)
forward_returns_list = ParseYamlReturnsWithName(function_returns)
forward_returns_list = ParseYamlReturns(function_returns)

return forward_inputs_list, forward_attrs_list, forward_returns_list

Expand Down Expand Up @@ -296,7 +283,7 @@ def ParseYamlBackward(args_str, returns_str):
args_str = re.search(args_pattern, args_str).group(1)

inputs_list, attrs_list = ParseYamlArgs(args_str)
returns_list = ParseYamlReturnsWithName(returns_str)
returns_list = ParseYamlReturns(returns_str)

return inputs_list, attrs_list, returns_list

Expand Down

0 comments on commit d2a911b

Please sign in to comment.