diff --git a/onnx_pytorch/op_code_generators/Conv.py b/onnx_pytorch/op_code_generators/Conv.py
index 7e0c2f2..534229f 100644
--- a/onnx_pytorch/op_code_generators/Conv.py
+++ b/onnx_pytorch/op_code_generators/Conv.py
@@ -9,54 +9,83 @@ class ConvOpCodeGenerator(OpCodeGenerator):
-  def __init__(self,
-               onnx_ver=onnx.defs.onnx_opset_version(),
-               torch_ver=torch.__version__):
-    super(ConvOpCodeGenerator, self).__init__(onnx_ver, torch_ver)
+  def __init__(self,
+               onnx_ver=onnx.defs.onnx_opset_version(),
+               torch_ver=torch.__version__):
+    super(ConvOpCodeGenerator, self).__init__(onnx_ver, torch_ver)

-  def gen(self, node, value_infos, initializers):
-    attr_value_dict = self.get_attr_value_dict(node)
-    inputs_str, outputs_str = self.gen_input_output_string(
-        node, initializers, self.rename_helper, self.tensor_inplace)
+  def gen(self, node, value_infos, initializers):
+    flag_from_layer = False
+    attr_value_dict = self.get_attr_value_dict(node)
+    inputs_str, outputs_str = self.gen_input_output_string(
+        node, initializers, self.rename_helper, self.tensor_inplace)

-    d = len(value_infos[node.input[0]].type.tensor_type.shape.dim) - 2
-    assert (d in (1, 2, 3))
+    d = len(value_infos[node.input[0]].type.tensor_type.shape.dim) - 2
+    assert (d in (1, 2, 3))

-    nn_name = f"Conv{d}d"
-    node_name = self.rename_helper.get_node_name(node.name, node.op_type)
-    init_str, forward_str = [], []
-    padding = 0
-    if "pads" in attr_value_dict:
-      padding = [attr_value_dict["pads"][i] for i in range(d)]
-    elif attr_value_dict["auto_pad"] not in (b"NOTSET", b""):
-      logging.warning(
-          "auto_pad is a DEPRECATED attribute, will not guarantee the result.")
-      forward_str.append(
-          f"{inputs_str[0]} = self.compatible_auto_pad({inputs_str[0]}, self.{node_name}.weight.data.shape[2:], self.{node_name}, '{attr_value_dict['auto_pad'].decode('utf-8')}')"
-      )
-    weights = onnx.numpy_helper.to_array(initializers[node.input[1]])
-    params_str = self.gen_params_str(
-        groups=attr_value_dict["group"],
-        dilation=attr_value_dict.get("dilations", 1),
-        out_channels=weights.shape[0],
-        padding=padding,
-        kernel_size=weights.shape[2:].__repr__(),
-        stride=attr_value_dict.get("strides", 1),
-        in_channels=weights.shape[1] * attr_value_dict["group"],
-        bias=len(node.input) > 2)
+    nn_name = f"Conv{d}d"
+    node_name = self.rename_helper.get_node_name(node.name, node.op_type)
+    init_str, forward_str = [], []
+    padding = 0
+    if "pads" in attr_value_dict:
+      padding = [attr_value_dict["pads"][i] for i in range(d)]
+    elif attr_value_dict["auto_pad"] not in (b"NOTSET", b""):
+      logging.warning(
+          "auto_pad is a DEPRECATED attribute, will not guarantee the result.")
+      forward_str.append(
+          f"{inputs_str[0]} = self.compatible_auto_pad({inputs_str[0]}, self.{node_name}.weight.data.shape[2:], self.{node_name}, '{attr_value_dict['auto_pad'].decode('utf-8')}')"
+      )
+    weights_proto = initializers.get(node.input[1], None)
+    if weights_proto is None:
+      # The weight is not a graph initializer; it is produced by an upstream
+      # layer, so only its inferred shape is available through value_infos.
+      flag_from_layer = True
+      weights_proto = value_infos.get(inputs_str[1], None)
+      if weights_proto is None:
+        raise NotImplementedError
+      shape = [
+          dim_info.dim_value
+          for dim_info in weights_proto.type.tensor_type.shape.dim
+      ]
+    else:
+      weights_np = onnx.numpy_helper.to_array(weights_proto)
+      shape = weights_np.shape
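# Note (illustrative sketch, not part of the patch): the value_infos lookup
# above only works when shape inference has populated an entry for the weight
# tensor. Minimal standalone version, with a hypothetical model path and
# tensor name:
import onnx
from onnx import shape_inference

model = shape_inference.infer_shapes(onnx.load("model.onnx"))
value_infos = {vi.name: vi for vi in model.graph.value_info}
dims = value_infos["conv_weight"].type.tensor_type.shape.dim
shape = [d.dim_value for d in dims]  # dim_value is 0 for symbolic dims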
-    init_str.append(f"self.{node_name} = nn.{nn_name}(**{{{params_str}}})")
-    init_str.append(f"self.{node_name}.weight.data = {inputs_str[1]}")
-    if len(node.input) > 2:
-      init_str.append(f"self.{node_name}.bias.data = {inputs_str[2]}")
-
-    forward_str.append(f"{outputs_str[0]} = self.{node_name}({inputs_str[0]})")
-
-    return {"init": init_str, "forward": forward_str}
+    if not flag_from_layer:
+      params_str = self.gen_params_str(
+          groups=attr_value_dict["group"],
+          dilation=attr_value_dict.get("dilations", 1),
+          out_channels=shape[0],
+          padding=padding,
+          kernel_size=shape[2:].__repr__(),
+          stride=attr_value_dict.get("strides", 1),
+          in_channels=shape[1] * attr_value_dict["group"],
+          bias=len(node.input) > 2)
+      init_str.append(f"self.{node_name} = nn.{nn_name}(**{{{params_str}}})")
+      init_str.append(f"self.{node_name}.weight.data = {inputs_str[1]}")
+      if len(node.input) > 2:
+        init_str.append(f"self.{node_name}.bias.data = {inputs_str[2]}")
+      forward_str.append(f"{outputs_str[0]} = self.{node_name}({inputs_str[0]})")
+    else:
+      # The weight only exists at run time, so emit a functional call instead
+      # of constructing an nn.Conv module in __init__. Only kwargs accepted
+      # by F.convNd are forwarded; out_channels/in_channels/kernel_size are
+      # implied by the weight tensor itself.
+      params_str = self.gen_params_str(
+          groups=attr_value_dict["group"],
+          dilation=attr_value_dict.get("dilations", 1),
+          padding=padding,
+          stride=attr_value_dict.get("strides", 1))
+      bias_str = inputs_str[2] if len(node.input) > 2 else "None"
+      forward_str.append(
+          f"{outputs_str[0]} = F.conv{d}d({inputs_str[0]}, {inputs_str[1]}, {bias_str}, **{{{params_str}}})"
+      )
+
+    return {"init": init_str, "forward": forward_str}

-  @staticmethod
-  def gen_method():
-    return '''def compatible_auto_pad(self, input, kernel_spatial_shape, nn_mod, auto_pad=None, **kwargs):
+  @staticmethod
+  def gen_method():
+    return '''def compatible_auto_pad(self, input, kernel_spatial_shape, nn_mod, auto_pad=None, **kwargs):
   input_spatial_shape = input.shape[2:]
   d = len(input_spatial_shape)
   strides = nn_mod.stride
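# Sanity sketch (not part of the patch): the functional fallback emitted by
# the else-branch above should match the nn.Conv2d path exactly. Sizes below
# are made up:
import torch
import torch.nn.functional as F

x = torch.randn(1, 4, 8, 8)
w = torch.randn(6, 2, 3, 3)  # (out_channels, in_channels // groups, kH, kW)
b = torch.randn(6)

y_func = F.conv2d(x, w, b, stride=1, padding=1, dilation=1, groups=2)
conv = torch.nn.Conv2d(4, 6, 3, stride=1, padding=1, dilation=1, groups=2)
conv.weight.data, conv.bias.data = w, b
assert torch.allclose(y_func, conv(x))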
diff --git a/onnx_pytorch/op_code_generators/ConvTranspose.py b/onnx_pytorch/op_code_generators/ConvTranspose.py
index e2b493e..cdb864d 100644
--- a/onnx_pytorch/op_code_generators/ConvTranspose.py
+++ b/onnx_pytorch/op_code_generators/ConvTranspose.py
@@ -7,66 +7,95 @@ class ConvTransposeOpCodeGenerator(OpCodeGenerator):
-  def __init__(self,
-               onnx_ver=onnx.defs.onnx_opset_version(),
-               torch_ver=torch.__version__):
-    super(ConvTransposeOpCodeGenerator, self).__init__(onnx_ver, torch_ver)
+  def __init__(self,
+               onnx_ver=onnx.defs.onnx_opset_version(),
+               torch_ver=torch.__version__):
+    super(ConvTransposeOpCodeGenerator, self).__init__(onnx_ver, torch_ver)

-  def gen(self, node, value_infos, initializers):
-    attr_value_dict = self.get_attr_value_dict(node)
-    inputs_str, outputs_str = self.gen_input_output_string(
-        node, initializers, self.rename_helper, self.tensor_inplace)
+  def gen(self, node, value_infos, initializers):
+    flag_from_layer = False
+    attr_value_dict = self.get_attr_value_dict(node)
+    inputs_str, outputs_str = self.gen_input_output_string(
+        node, initializers, self.rename_helper, self.tensor_inplace)

-    d = len(value_infos[node.input[0]].type.tensor_type.shape.dim) - 2
-    input_size = [
-        d.dim_value
-        for d in value_infos[node.input[0]].type.tensor_type.shape.dim
-    ][2:]
-    assert (d in (1, 2, 3))
+    d = len(value_infos[node.input[0]].type.tensor_type.shape.dim) - 2
+    input_size = [
+        d.dim_value
+        for d in value_infos[node.input[0]].type.tensor_type.shape.dim
+    ][2:]
+    assert (d in (1, 2, 3))

-    weights = onnx.numpy_helper.to_array(initializers[node.input[1]])
-    padding = [0] * d
-    output_padding = [0] * d
-    stride = attr_value_dict.get("strides", [1] * d)
-    kernel_shape = weights.shape[2:]
-    dilation = attr_value_dict.get("dilations", [1] * d)
-    if "pads" in attr_value_dict:
-      padding = [attr_value_dict["pads"][i] for i in range(d)]
-    if "output_padding" in attr_value_dict:
-      output_padding = [attr_value_dict["output_padding"][i] for i in range(d)]
-    if "output_shape" in attr_value_dict:
-      output_shape = attr_value_dict["output_shape"]
-      total_padding = [0] * d
+    weights_proto = initializers.get(node.input[1], None)
+    if weights_proto is None:
+      # The weight is not a graph initializer; it is produced by an upstream
+      # layer, so only its inferred shape is available through value_infos.
+      flag_from_layer = True
+      weights_proto = value_infos.get(inputs_str[1], None)
+      if weights_proto is None:
+        raise NotImplementedError
+      shape = [
+          dim_info.dim_value
+          for dim_info in weights_proto.type.tensor_type.shape.dim
+      ]
+    else:
+      weights_np = onnx.numpy_helper.to_array(weights_proto)
+      shape = weights_np.shape

-      # total_padding[i] = stride[i] * (input_size[i] - 1) + output_padding[i] + ((kernel_shape[i] - 1) * dilations[i] + 1) - output_shape[i]
-      # If (auto_pads == SAME_UPPER): pads[start_i] = total_padding[i]/2; pads[end_i] = total_padding[i] - (total_padding[i]/2)
-      # Else: pads[start_i] = total_padding[i] - (total_padding[i]/2); pads[end_i] = (total_padding[i]/2).
+    padding = [0] * d
+    output_padding = [0] * d
+    stride = attr_value_dict.get("strides", [1] * d)
+    kernel_shape = shape[2:]
+    dilation = attr_value_dict.get("dilations", [1] * d)
+    if "pads" in attr_value_dict:
+      padding = [attr_value_dict["pads"][i] for i in range(d)]
+    if "output_padding" in attr_value_dict:
+      output_padding = [attr_value_dict["output_padding"][i] for i in range(d)]
+    if "output_shape" in attr_value_dict:
+      output_shape = attr_value_dict["output_shape"]
+      total_padding = [0] * d

-    for i in range(d):
-      total_padding[i] = stride[i] * (
-          input_size[i] - 1) + output_padding[i] + (
-              (kernel_shape[i] - 1) * dilation[i] + 1) - output_shape[i]
-      assert total_padding[
-          i] % 2 == 0, "Padding for ConvTranspose should be even."
-      padding[i] = total_padding[i] // 2
-    params_str = self.gen_params_str(groups=attr_value_dict["group"],
-                                     dilation=dilation,
-                                     out_channels=weights.shape[1],
-                                     padding=padding,
-                                     output_padding=output_padding,
-                                     kernel_size=weights.shape[2:],
-                                     stride=stride,
-                                     in_channels=weights.shape[0],
-                                     bias=len(node.input) > 2)
+      # total_padding[i] = stride[i] * (input_size[i] - 1) + output_padding[i] + ((kernel_shape[i] - 1) * dilations[i] + 1) - output_shape[i]
+      # If (auto_pads == SAME_UPPER): pads[start_i] = total_padding[i]/2; pads[end_i] = total_padding[i] - (total_padding[i]/2)
+      # Else: pads[start_i] = total_padding[i] - (total_padding[i]/2); pads[end_i] = (total_padding[i]/2).

-    nn_name = f"ConvTranspose{d}d"
-    node_name = self.rename_helper.get_node_name(node.name, node.op_type)
-    init_str, forward_str = [], []
-    init_str.append(f"self.{node_name} = nn.{nn_name}(**{{{params_str}}})")
-    init_str.append(f"self.{node_name}.weight.data = {inputs_str[1]}")
-    if len(node.input) > 2:
-      init_str.append(f"self.{node_name}.bias.data = {inputs_str[2]}")
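# Worked example for the total_padding computation in the loop below
# (illustrative numbers): input_size=5, stride=2, kernel_shape=4, dilation=1,
# output_padding=0, requested output_shape=10:
#   total_padding = 2 * (5 - 1) + 0 + ((4 - 1) * 1 + 1) - 10 = 8 + 4 - 10 = 2
# which is even, so padding = 2 // 2 = 1 on each side. An odd total (e.g.
# output_shape=11 gives total_padding=1) has no symmetric PyTorch padding,
# hence the assert.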
+      for i in range(d):
+        total_padding[i] = stride[i] * (
+            input_size[i] - 1) + output_padding[i] + (
+                (kernel_shape[i] - 1) * dilation[i] + 1) - output_shape[i]
+        assert total_padding[
+            i] % 2 == 0, "Padding for ConvTranspose should be even."
+        padding[i] = total_padding[i] // 2

-    forward_str.append(f"{outputs_str[0]} = self.{node_name}({inputs_str[0]})")
-
-    return {"init": init_str, "forward": forward_str}
+    nn_name = f"ConvTranspose{d}d"
+    node_name = self.rename_helper.get_node_name(node.name, node.op_type)
+    init_str, forward_str = [], []
+    if not flag_from_layer:
+      params_str = self.gen_params_str(
+          groups=attr_value_dict["group"],
+          dilation=dilation,
+          # ONNX ConvTranspose weights are (in_channels, out_channels // groups, *kernel_shape).
+          out_channels=shape[1] * attr_value_dict["group"],
+          padding=padding,
+          output_padding=output_padding,
+          kernel_size=shape[2:],
+          stride=stride,
+          in_channels=shape[0],
+          bias=len(node.input) > 2)
+      init_str.append(f"self.{node_name} = nn.{nn_name}(**{{{params_str}}})")
+      init_str.append(f"self.{node_name}.weight.data = {inputs_str[1]}")
+      if len(node.input) > 2:
+        init_str.append(f"self.{node_name}.bias.data = {inputs_str[2]}")
+      forward_str.append(f"{outputs_str[0]} = self.{node_name}({inputs_str[0]})")
+    else:
+      # The weight only exists at run time, so emit a functional call instead
+      # of constructing an nn module in __init__.
+      params_str = self.gen_params_str(
+          groups=attr_value_dict["group"],
+          dilation=dilation,
+          padding=padding,
+          output_padding=output_padding,
+          stride=stride)
+      bias_str = inputs_str[2] if len(node.input) > 2 else "None"
+      forward_str.append(
+          f"{outputs_str[0]} = F.conv_transpose{d}d({inputs_str[0]}, {inputs_str[1]}, {bias_str}, **{{{params_str}}})"
+      )
+
+    return {"init": init_str, "forward": forward_str}
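# Sanity sketch (not part of the patch): ONNX stores ConvTranspose weights as
# (in_channels, out_channels // groups, kH, kW), the same layout torch uses,
# so the tensor can be passed to F.conv_transpose2d unchanged. Sizes reuse
# the worked example above:
import torch
import torch.nn.functional as F

x = torch.randn(1, 4, 5, 5)
w = torch.randn(4, 3, 4, 4)  # (C_in, C_out // groups, kH, kW), groups=1
y = F.conv_transpose2d(x, w, stride=2, padding=1, output_padding=0)
print(y.shape)  # torch.Size([1, 3, 10, 10])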
diff --git a/onnx_pytorch/op_code_generators/Pad.py b/onnx_pytorch/op_code_generators/Pad.py
index 1778df5..fe880a6 100644
--- a/onnx_pytorch/op_code_generators/Pad.py
+++ b/onnx_pytorch/op_code_generators/Pad.py
@@ -6,33 +6,43 @@ class PadOpCodeGenerator(OpCodeGenerator):
-  def __init__(self,
-               onnx_ver=onnx.defs.onnx_opset_version(),
-               torch_ver=torch.__version__):
-    super(PadOpCodeGenerator, self).__init__(onnx_ver, torch_ver)
+  def __init__(self,
+               onnx_ver=onnx.defs.onnx_opset_version(),
+               torch_ver=torch.__version__):
+    super(PadOpCodeGenerator, self).__init__(onnx_ver, torch_ver)

-  def gen(self, node, value_infos, initializers):
-    attr_value_dict = self.get_attr_value_dict(node)
-    inputs_str, outputs_str = self.gen_input_output_string(
-        node, initializers, self.rename_helper, self.tensor_inplace)
-    init_str, forward_str = [], []
-    mode = attr_value_dict.get("mode", b"constant")
-    value = 0.
-    if mode == b"constant":
-      if len(node.input) == 3:
-        value = onnx.numpy_helper.to_array(initializers[node.input[2]])[0]
-    d = len(value_infos[node.input[0]].type.tensor_type.shape.dim) - 2
-    if len(node.input) > 1:
-      pads = initializers.get(node.input[1], None)
-      assert pads is not None, "Currently PadOpCodeGenerator only support all of [pads] is in initializers."
-      pads = onnx.numpy_helper.to_array(pads)
-    else:
-      pads = attr_value_dict["pads"]
-    pt_pads = [0, 0] * d
-    for i in range(d):
-      pt_pads[2 * (d - i - 1)] = pads[2 + i]
-      pt_pads[2 * (d - i - 1) + 1] = pads[d + 2 + 2 + i]
-    forward_str.append(
-        f"{outputs_str[0]} = F.pad({inputs_str[0]}, {pt_pads.__repr__()}, \"{mode.decode()}\", {value})"
-    )
-    return {"init": init_str, "forward": forward_str}
+  def gen(self, node, value_infos, initializers):
+    attr_value_dict = self.get_attr_value_dict(node)
+    inputs_str, outputs_str = self.gen_input_output_string(
+        node, initializers, self.rename_helper, self.tensor_inplace)
+    init_str, forward_str = [], []
+    mode = attr_value_dict.get("mode", b"constant")
+    value = 0.
+    if mode == b"constant":
+      if len(node.input) == 3:  # opset >= 11: constant_value is the third input
+        value = onnx.numpy_helper.to_array(initializers[node.input[2]])[0]
+    if len(node.input) > 1:  # opset >= 11: pads is the second input
+      pads = initializers.get(node.input[1], None)
+      assert pads is not None, "Currently PadOpCodeGenerator only supports [pads] provided as an initializer."
+      pads = onnx.numpy_helper.to_array(pads)
+    else:
+      pads = attr_value_dict["pads"]
+    # Pad every dimension, not only the spatial ones. ONNX pads are
+    # [x1_begin, x2_begin, ..., x1_end, x2_end, ...] over all dims, while
+    # torch F.pad takes (begin, end) pairs starting from the last dim, so
+    # dim i maps to the pair at offset 2 * (rank - 1 - i).
+    rank = len(value_infos[node.input[0]].type.tensor_type.shape.dim)
+    pt_pads = [0, 0] * rank
+    for i in range(rank):
+      pt_pads[2 * (rank - 1 - i)] = int(pads[i])  # begin pad of dim i
+      pt_pads[2 * (rank - 1 - i) + 1] = int(pads[i + rank])  # end pad of dim i
+    forward_str.append(
+        f"{outputs_str[0]} = F.pad({inputs_str[0]}, {pt_pads.__repr__()}, \"{mode.decode()}\", {value})"
+    )
+    return {"init": init_str, "forward": forward_str}
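# Worked example (not part of the patch) of the pads reordering above, for a
# rank-4 input: ONNX pads [n_b, c_b, h_b, w_b, n_e, c_e, h_e, w_e] become
# torch pads (w_b, w_e, h_b, h_e, c_b, c_e, n_b, n_e), last dimension first:
import torch
import torch.nn.functional as F

x = torch.zeros(1, 1, 2, 2)
onnx_pads = [0, 0, 1, 0, 0, 0, 0, 2]  # pad H by 1 at begin, W by 2 at end
pt_pads = [0, 2, 1, 0, 0, 0, 0, 0]    # what the loop above produces
y = F.pad(x, pt_pads, "constant", 0.)
print(y.shape)  # torch.Size([1, 1, 3, 4])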
diff --git a/onnx_pytorch/op_code_generators/RandomNormal.py b/onnx_pytorch/op_code_generators/RandomNormal.py
new file mode 100644
index 0000000..4cd0d0f
--- /dev/null
+++ b/onnx_pytorch/op_code_generators/RandomNormal.py
@@ -0,0 +1,36 @@
+import onnx
+import onnx.numpy_helper
+import torch
+from onnx.mapping import TENSOR_TYPE_TO_NP_TYPE
+
+from onnx_pytorch.op_code_generators import OpCodeGenerator
+
+
+class RandomNormalOpCodeGenerator(OpCodeGenerator):
+
+  def __init__(self,
+               onnx_ver=onnx.defs.onnx_opset_version(),
+               torch_ver=torch.__version__):
+    super(RandomNormalOpCodeGenerator, self).__init__(onnx_ver, torch_ver)
+
+  def gen(self, node, value_infos, initializers):
+    attr_value_dict = self.get_attr_value_dict(node)
+    params = dict(
+        mean=attr_value_dict["mean"],
+        std=attr_value_dict["scale"],
+        size=tuple(attr_value_dict["shape"]),
+        dtype=f"torch.{TENSOR_TYPE_TO_NP_TYPE[attr_value_dict['dtype']]}")
+    if "seed" in attr_value_dict:
+      # torch.manual_seed seeds and returns the default Generator, so the
+      # generated call reproduces the requested seed. ONNX stores seed as a
+      # float attribute; torch expects an int.
+      params["generator"] = f"torch.manual_seed({int(attr_value_dict['seed'])})"
+    params_str = self.gen_params_str(**params)
+    inputs_str, outputs_str = self.gen_input_output_string(
+        node, initializers, self.rename_helper, self.tensor_inplace)
+    init_str, forward_str = [], []
+    forward_str.append(f"{outputs_str[0]} = torch.normal(**{{{params_str}}})")
+    return {"init": init_str, "forward": forward_str}
diff --git a/onnx_pytorch/op_code_generators/Slice.py b/onnx_pytorch/op_code_generators/Slice.py
index e1df1ed..5b93bc5 100644
--- a/onnx_pytorch/op_code_generators/Slice.py
+++ b/onnx_pytorch/op_code_generators/Slice.py
@@ -7,67 +7,85 @@ class SliceOpCodeGenerator(OpCodeGenerator):
-  def __init__(self,
-               onnx_ver=onnx.defs.onnx_opset_version(),
-               torch_ver=torch.__version__):
-    super(SliceOpCodeGenerator, self).__init__(onnx_ver, torch_ver)
+  def __init__(self,
+               onnx_ver=onnx.defs.onnx_opset_version(),
+               torch_ver=torch.__version__):
+    super(SliceOpCodeGenerator, self).__init__(onnx_ver, torch_ver)

-  def gen(self, node, value_infos, initializers):
-    attr_value_dict = self.get_attr_value_dict(node)
-    inputs_str, outputs_str = self.gen_input_output_string(
-        node, initializers, self.rename_helper)
-    init_str, forward_str = [], []
-    d = len(value_infos[node.input[0]].type.tensor_type.shape.dim)
-    starts, ends, axes, steps = self._get_starts_ends_axes_steps(
-        attr_value_dict, d, node, initializers)
-    slice_str = []
-    for i in range(d):
-      if i in axes:
-        j = axes.index(i)
-        s = ["", ""]
-        if type(starts) == str and type(ends) == str:
-          s[0] = f'{starts}[{j}] if {starts}[{j}]'
-          s[1] = f'{ends}[{j}] if {ends}[{j}]'
-        else:
-          s = [
-              str(starts[j]) if starts[j] != 0 else "",
-              str(ends[j]) if ends[j] < 2**31 else ""
-          ]
-        if steps[j] != 1:
-          s.append(str(steps[j]))
-        slice_str.append(":".join(s))
-      else:
-        slice_str.append(":")
+  def gen(self, node, value_infos, initializers):
+    attr_value_dict = self.get_attr_value_dict(node)
+    inputs_str, outputs_str = self.gen_input_output_string(
+        node, initializers, self.rename_helper)
+    init_str, forward_str = [], []
+    d = len(value_infos[node.input[0]].type.tensor_type.shape.dim)
+    starts, ends, axes, steps = self._get_starts_ends_axes_steps(
+        attr_value_dict, d, node, initializers)
+    slice_str = []
+    flip_dims = []
+    for i in range(d):
+      if i in axes:
+        j = axes.index(i)
+        s = ["", ""]
+        if steps[j] < 0:
+          # PyTorch slicing rejects negative steps, so take the elements
+          # with a positive step and flip afterwards. ONNX visits starts[j],
+          # starts[j] + steps[j], ... and stops before ends[j], so the
+          # equivalent forward slice is [ends[j] + 1 : starts[j] + 1].
+          # (Exact only when |step| divides the span; the INT_MIN/INT_MAX
+          # sentinels map to open bounds.)
+          flip_dims.append(i)
+          temp_step = -steps[j]
+          if type(starts) == str and type(ends) == str:
+            s[0] = f'{ends}[{j}] if {ends}[{j}]'
+            s[1] = f'{starts}[{j}] if {starts}[{j}]'
+          else:
+            s = [
+                str(ends[j] + 1) if ends[j] > -2**31 else "",
+                str(starts[j] + 1) if starts[j] < 2**31 - 1 else ""
+            ]
+        else:
+          temp_step = steps[j]
+          if type(starts) == str and type(ends) == str:
+            s[0] = f'{starts}[{j}] if {starts}[{j}]'
+            s[1] = f'{ends}[{j}] if {ends}[{j}]'
+          else:
+            s = [
+                str(starts[j]) if starts[j] != 0 else "",
+                str(ends[j]) if ends[j] < 2**31 else ""
+            ]
+        if temp_step != 1:
+          s.append(str(temp_step))
+        slice_str.append(":".join(s))
+      else:
+        slice_str.append(":")

-    forward_str.append(
-        f"{outputs_str[0]} = {inputs_str[0]}[{', '.join(slice_str)}]")
-    return {"init": init_str, "forward": forward_str}
+    forward_str.append(
+        f"{outputs_str[0]} = {inputs_str[0]}[{', '.join(slice_str)}]")
+    if len(flip_dims) > 0:
+      forward_str.append(
+          f"{outputs_str[0]} = {outputs_str[0]}.flip({flip_dims})")
+    return {"init": init_str, "forward": forward_str}
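# Sanity sketch (not part of the patch): what the generated code above does
# for an ONNX Slice with starts=[4], ends=[INT_MIN], axes=[0], steps=[-1],
# which visits indices 4, 3, 2, 1, 0:
import torch

x = torch.arange(6)
y = x[:5]        # forward slice [ends + 1 : starts + 1]; INT_MIN becomes an open bound
y = y.flip([0])
print(y)  # tensor([4, 3, 2, 1, 0]), matching the ONNX traversal order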
-  def _get_starts_ends_axes_steps(self, attr_value_dict, d, node, initializers):
-    axes = list(range(d))
-    steps = [1] * len(axes)
-    if self.onnx_ver > 1 and len(node.input) > 1:
-      starts = initializers.get(node.input[1], None)
-      ends = initializers.get(node.input[2], None)
-      if starts is None:
-        starts = node.input[1]
-      else:
-        starts = to_array(starts)
-      if ends is None:
-        ends = node.input[2]
-      else:
-        ends = to_array(ends)
-      if len(node.input) > 3:
-        axes = initializers.get(node.input[3], None)
-      if len(node.input) > 4:
-        steps = initializers.get(node.input[4], None)
-      assert starts is not None or ends is not None or axes is not None or steps is not None, "Currently SliceOpCodeGenerator only support all of [starts, ends, axes, steps] is in initializers."
-      if len(node.input) > 3:
-        axes = to_array(axes)
-      if len(node.input) > 4:
-        steps = to_array(steps)
-    else:
-      starts = attr_value_dict["starts"]
-      ends = attr_value_dict["ends"]
-      axes = attr_value_dict.get("axes", axes)
-    return starts, ends, list(axes), list(steps)
+  def _get_starts_ends_axes_steps(self, attr_value_dict, d, node, initializers):
+    axes = list(range(d))
+    steps = [1] * len(axes)
+    if self.onnx_ver > 1 and len(node.input) > 1:
+      starts = initializers.get(node.input[1], None)
+      ends = initializers.get(node.input[2], None)
+      if starts is None:
+        starts = node.input[1]
+      else:
+        starts = to_array(starts)
+      if ends is None:
+        ends = node.input[2]
+      else:
+        ends = to_array(ends)
+      if len(node.input) > 3:
+        axes = initializers.get(node.input[3], None)
+      if len(node.input) > 4:
+        steps = initializers.get(node.input[4], None)
+      assert starts is not None or ends is not None or axes is not None or steps is not None, "Currently SliceOpCodeGenerator only supports the case where [starts, ends, axes, steps] are initializers."
+      if len(node.input) > 3:
+        axes = to_array(axes)
+      if len(node.input) > 4:
+        steps = to_array(steps)
+    else:
+      starts = attr_value_dict["starts"]
+      ends = attr_value_dict["ends"]
+      axes = attr_value_dict.get("axes", axes)
+    return starts, ends, list(axes), list(steps)
diff --git a/onnx_pytorch/op_code_generators/Tile.py b/onnx_pytorch/op_code_generators/Tile.py
new file mode 100644
index 0000000..bea1d51
--- /dev/null
+++ b/onnx_pytorch/op_code_generators/Tile.py
@@ -0,0 +1,36 @@
+import onnx
+import onnx.numpy_helper
+import torch
+
+from onnx_pytorch.op_code_generators import OpCodeGenerator
+
+
+class TileOpCodeGenerator(OpCodeGenerator):
+
+  def __init__(self,
+               onnx_ver=onnx.defs.onnx_opset_version(),
+               torch_ver=torch.__version__):
+    super(TileOpCodeGenerator, self).__init__(onnx_ver, torch_ver)
+
+  def gen(self, node, value_infos, initializers):
+    inputs_str, outputs_str = self.gen_input_output_string(
+        node, initializers, self.rename_helper, self.tensor_inplace)
+
+    init_str, forward_str = [], []
+    repeats_proto = initializers.get(node.input[1], None)
+    assert repeats_proto is not None, "Currently TileOpCodeGenerator only supports [repeats] provided as an initializer."
+    repeats_np = onnx.numpy_helper.to_array(repeats_proto)
+    # Cast the numpy ints so the tuple reprs cleanly in the generated code.
+    repeats = tuple(int(i) for i in repeats_np)
+    params_str = self.gen_params_str(dims=repeats)
+    forward_str.append(
+        f"{', '.join(outputs_str)} = torch.tile({inputs_str[0]}, **{{{params_str}}})")
+    return {"init": init_str, "forward": forward_str}
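# Sanity sketch (not part of the patch): ONNX Tile repeats map one-to-one to
# torch.tile's dims argument:
import torch

x = torch.tensor([[1, 2], [3, 4]])
y = torch.tile(x, dims=(2, 3))  # 2x along dim 0, 3x along dim 1
print(y.shape)  # torch.Size([4, 6])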