"""Utility functions for onnx operator implementations."""
from typing import Callable, Tuple, Union
import numpy
from concrete.fhe import conv as fhe_conv
from concrete.fhe import ones as fhe_ones
from concrete.fhe import truncate_bit_pattern
from concrete.fhe.tracing import Tracer
from ..common.debugging import assert_true
ComparisonOperationType = Callable[[int], bool]
def numpy_onnx_pad(
x: numpy.ndarray,
pads: Tuple[int, ...],
pad_value: Union[float, int, numpy.ndarray] = 0,
int_only: bool = False,
) -> Union[numpy.ndarray, Tracer]:
"""Pad a tensor according to ONNX spec, using an optional custom pad value.
Args:
x (numpy.ndarray): input tensor to pad
pads (List[int]): padding values according to ONNX spec
pad_value (Optional[Union[float, int]]): value used to fill in padding, default 0
int_only (bool): set to True to generate integer only code with Concrete
Returns:
res(numpy.ndarray): the input tensor with padding applied
"""
x_pad = x
if numpy.any(numpy.asarray(pads) > 0):
        # The tensor to pad has shape N x C x dim1 x dim2 x ... x dimN, where
        # dim1 x dim2 x ... x dimN are the spatial dimensions, N is the batch size
        # and C is the number of channels; the first two dimensions are not padded
        # Need to get the number of spatial dimensions
        ndim = x.ndim - 2
# Compute padded shape, keep batch size and channels number
# Add the pads to the other dimensions.
# Pads are in the form of
# dim1_start, dim2_start, ..., dimN_start, dim1_end, dim2_end, ... dimN_end
padded_shape = [x.shape[0], x.shape[1]]
padded_shape += [x.shape[i + 2] + pads[i] + pads[ndim + i] for i in range(ndim)]
        # Initialize a padded version of the input, setting the values on the
        # edges to the pad value (e.g., the input zero-point, which corresponds
        # to a real-valued 0)
if int_only:
if isinstance(x_pad, Tracer):
# Quantized execution: integer mode with tracing
x_pad = fhe_ones(tuple(padded_shape)) * numpy.int64(pad_value)
else:
# Quantized execution: integer mode without tracing
x_pad = numpy.ones(padded_shape, dtype=numpy.int64) * pad_value
else:
# Calibration mode: floating-point padding for non-quantized execution
x_pad = numpy.ones(padded_shape, dtype=numpy.float32) * pad_value
assert isinstance(x_pad, (numpy.ndarray, Tracer))
        # Create the indices for slice assignment: keep the full batch and channel
        # dimensions, and place x inside the padded area
indices = [slice(None), slice(None)] + [
slice(pads[i], x_pad.shape[i + 2] - pads[ndim + i]) for i in range(ndim)
]
x_pad[tuple(indices)] = x
return x_pad
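

# A minimal usage sketch, not part of the original module: pad a 1 x 1 x 2 x 2
# integer tensor with one row and one column on each side. Pads follow the ONNX
# layout (dim1_start, dim2_start, dim1_end, dim2_end).
def _example_numpy_onnx_pad() -> numpy.ndarray:
    x = numpy.arange(4, dtype=numpy.int64).reshape(1, 1, 2, 2)
    padded = numpy_onnx_pad(x, (1, 1, 1, 1), pad_value=0, int_only=True)
    # Spatial dimensions grow from 2x2 to 4x4; the original values sit in the center
    assert padded.shape == (1, 1, 4, 4)
    return padded

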
def compute_conv_output_dims(
input_shape: Tuple[int, ...],
kernel_shape: Tuple[int, ...],
pads: Tuple[int, ...],
strides: Tuple[int, ...],
ceil_mode: int,
) -> Tuple[int, ...]:
"""Compute the output shape of a pool or conv operation.
See https://pytorch.org/docs/stable/generated/torch.nn.AvgPool2d.html for details
on the computation of the output shape.
Args:
input_shape (Tuple[int, ...]): shape of the input to be padded as N x C x H x W
kernel_shape (Tuple[int, ...]): shape of the conv or pool kernel, as Kh x Kw (or n-d)
pads (Tuple[int, ...]): padding values following ONNX spec:
dim1_start, dim2_start, .. dimN_start, dim1_end, dim2_end, ... dimN_end
where in the 2-d case dim1 is H, dim2 is W
strides (Tuple[int, ...]): strides for each dimension
ceil_mode (int): set to 1 to use the `ceil` function to compute the output shape, as
described in the PyTorch doc
Returns:
res (Tuple[int, ...]): shape of the output of a conv or pool operator with given parameters
"""
assert_true(ceil_mode in {0, 1})
height_out = (input_shape[2] + pads[0] + pads[2] - kernel_shape[0]) / strides[0] + 1
width_out = (input_shape[3] + pads[1] + pads[3] - kernel_shape[1]) / strides[1] + 1
if ceil_mode == 0:
height_out = numpy.floor(height_out)
width_out = numpy.floor(width_out)
else:
height_out = numpy.ceil(height_out)
width_out = numpy.ceil(width_out)
height_out = int(height_out)
width_out = int(width_out)
return (input_shape[0], input_shape[1], height_out, width_out)
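

# A minimal sketch, not part of the original module: a 3x3 kernel with stride 2
# over a 6x6 input gives (6 - 3) / 2 + 1 = 2.5 positions per axis, so floor mode
# (PyTorch default) yields 2x2 spatial outputs while ceil mode yields 3x3.
def _example_compute_conv_output_dims() -> None:
    assert compute_conv_output_dims((1, 1, 6, 6), (3, 3), (0, 0, 0, 0), (2, 2), 0) == (1, 1, 2, 2)
    assert compute_conv_output_dims((1, 1, 6, 6), (3, 3), (0, 0, 0, 0), (2, 2), 1) == (1, 1, 3, 3)

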
def compute_onnx_pool_padding(
input_shape: Tuple[int, ...],
kernel_shape: Tuple[int, ...],
pads: Tuple[int, ...],
strides: Tuple[int, ...],
ceil_mode: int,
) -> Tuple[int, ...]:
"""Compute any additional padding needed to compute pooling layers.
The ONNX standard uses ceil_mode=1 to match TensorFlow style pooling output computation.
In this setting, the kernel can be placed at a valid position even though it contains values
outside of the input shape including padding. The ceil_mode parameter controls whether
this mode is enabled. If the mode is not enabled, the output shape follows PyTorch rules.
Args:
input_shape (Tuple[int, ...]): shape of the input to be padded as N x C x H x W
kernel_shape (Tuple[int, ...]): shape of the conv or pool kernel, as Kh x Kw (or n-d)
pads (Tuple[int, ...]): padding values following ONNX spec:
dim1_start, dim2_start, .. dimN_start, dim1_end, dim2_end, ... dimN_end
where in the 2-d case dim1 is H, dim2 is W
strides (Tuple[int, ...]): strides for each dimension
ceil_mode (int): set to 1 to use the `ceil` function to compute the output shape, as
described in the PyTorch doc
Returns:
res (Tuple[int, ...]): shape of the output of a conv or pool operator with given parameters
"""
pads2 = list(pads)
if ceil_mode == 1:
# We will pad the input with additional rows to respect TensorFlow style
# padding (ceil_mode == 1)
# Compute the dimensions for floor/ceil output computation modes
dims_floor = compute_conv_output_dims(input_shape, kernel_shape, pads, strides, 0)
dims_ceil = compute_conv_output_dims(input_shape, kernel_shape, pads, strides, 1)
# Compute the amount of additional padding necessary
# The additional padding should be done down on Y and right on X
pads2[2] += dims_ceil[2] - dims_floor[2] # pad_y_end += diff_y
pads2[3] += dims_ceil[3] - dims_floor[3] # pad_x_end += diff_x
return tuple(pads2)
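

# A minimal sketch, not part of the original module: for the same 6x6 input and
# 3x3 kernel with stride 2, floor mode leaves the pads unchanged, while ceil
# mode adds the ceil/floor output-size difference as end padding (bottom/right).
def _example_compute_onnx_pool_padding() -> None:
    assert compute_onnx_pool_padding((1, 1, 6, 6), (3, 3), (0, 0, 0, 0), (2, 2), 0) == (0, 0, 0, 0)
    assert compute_onnx_pool_padding((1, 1, 6, 6), (3, 3), (0, 0, 0, 0), (2, 2), 1) == (0, 0, 1, 1)

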
def onnx_avgpool_compute_norm_const(
input_shape: Tuple[int, ...],
kernel_shape: Tuple[int, ...],
pads: Tuple[int, ...],
strides: Tuple[int, ...],
ceil_mode: int,
) -> Union[numpy.ndarray, float, Tracer]:
"""Compute the average pooling normalization constant.
This constant can be a tensor of the same shape as the input or a scalar.
Args:
input_shape (Tuple[int, ...]): shape of the input to be padded as N x C x H x W
kernel_shape (Tuple[int, ...]): shape of the conv or pool kernel, as Kh x Kw (or n-d)
pads (Tuple[int, ...]): padding values following ONNX spec:
dim1_start, dim2_start, .. dimN_start, dim1_end, dim2_end, ... dimN_end
where in the 2-d case dim1 is H, dim2 is W
strides (Tuple[int, ...]): strides for each dimension
ceil_mode (int): set to 1 to use the `ceil` function to compute the output shape, as
described in the PyTorch doc
Returns:
res (float): tensor or scalar, corresponding to normalization factors to apply for the
average pool computation for each valid kernel position
"""
# Handle the TensorFlow pooling mode
if ceil_mode == 1:
n_in_channels = input_shape[1]
kernel = numpy.ones(
(n_in_channels, 1, kernel_shape[0], kernel_shape[1]),
dtype=numpy.int64,
)
        # TensorFlow (and ONNX pooling with ceil_mode==1) allows the pooling kernel
        # to be placed at positions that include out-of-bounds indices.
        # For example, take an input of size 2 containing values V, with a padding P
        # of 1 on the left and right:
        #     P V V P
        # A pooling kernel of size 2 with stride 1 can be applied at positions 0,1,2,3:
        #     (P+V)/2  (V+V)/2  (V+P)/2  P
        # even though at position 3 the kernel extends out of bounds.
        # When the kernel covers out-of-bounds indices, these are ignored and the
        # averaging only counts the valid values (P or V) in the kernel's support.
        # We thus need to find the number of valid indices for each kernel position.

        # Compute the padding values in floor mode (PyTorch)
        pool_pads_floor = compute_onnx_pool_padding(input_shape, kernel_shape, pads, strides, 0)

        # Compute them again in TensorFlow (ceil) mode
        pool_pads_ceil = compute_onnx_pool_padding(input_shape, kernel_shape, pads, strides, 1)

        # Build an indicator tensor with the ceil-mode padded shape: 1s mark indices
        # that are valid for averaging (those inside the floor-mode padded shape),
        # 0s mark the extra out-of-bounds rows/columns added by ceil mode
        padded_flr = numpy_onnx_pad(numpy.ones(input_shape, dtype=numpy.int64), pool_pads_floor, 1)
        padded_ceil = numpy_onnx_pad(numpy.zeros(input_shape, dtype=numpy.int64), pool_pads_ceil, 0)
        padded_ceil[:, :, 0 : padded_flr.shape[2], 0 : padded_flr.shape[3]] = 1
# Compute the sum of valid indices in each kernel position
norm_const = fhe_conv(
padded_ceil, kernel, None, [0, 0, 0, 0], strides, None, None, n_in_channels
)
else:
# For the PyTorch mode, only positions with all valid indices are used so
# the averaging is done over the number of cells in the kernel
norm_const = float(numpy.prod(numpy.array(kernel_shape)))
return norm_const
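

# A minimal sketch, not part of the original module: in PyTorch mode (ceil_mode=0)
# every kernel position is fully in-bounds, so the normalization constant is the
# scalar kernel size, here 3 * 3 = 9. With ceil_mode=1 a per-position tensor of
# valid-index counts would be returned instead.
def _example_onnx_avgpool_norm_const() -> None:
    norm = onnx_avgpool_compute_norm_const((1, 1, 6, 6), (3, 3), (0, 0, 0, 0), (2, 2), 0)
    assert norm == 9.0

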
def rounded_comparison(
x: numpy.ndarray, y: numpy.ndarray, lsbs_to_remove: int, operation: ComparisonOperationType
) -> Tuple[bool]:
"""Comparison operation using `round_bit_pattern` function.
`round_bit_pattern` rounds the bit pattern of an integer to the closer
It also checks for any potential overflow. If so, it readjusts the LSBs accordingly.
The parameter `lsbs_to_remove` in `round_bit_pattern` can either be an integer specifying the
number of LSBS to remove, or an `AutoRounder` object that determines the required number of LSBs
based on the specified number of MSBs to retain. But in our case, we choose to compute the LSBs
manually.
Args:
x (numpy.ndarray): Input tensor
y (numpy.ndarray): Input tensor
lsbs_to_remove (int): Number of the least significant bits to remove
operation (ComparisonOperationType): Comparison operation, which can `<`, `<=` and `==`
Returns:
Tuple[bool]: If x and y satisfy the comparison operator.
"""
    assert isinstance(lsbs_to_remove, int)

    # Truncate the LSBs of the difference to reduce its bit-width before comparison
    truncated_subtraction = truncate_bit_pattern(x - y, lsbs_to_remove=lsbs_to_remove)
    return (operation(truncated_subtraction),)
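

# A minimal sketch, not part of the original module: with two LSBs removed, the
# difference x - y is truncated in steps of 2**2 = 4 before the comparison is
# applied. This assumes `truncate_bit_pattern` evaluates directly on plain numpy
# arrays outside of a tracing context.
def _example_rounded_comparison() -> None:
    x = numpy.array([10, 3], dtype=numpy.int64)
    y = numpy.array([8, 9], dtype=numpy.int64)
    # x - y is [2, -6]; truncating 2 LSBs gives [0, -8], so "< 0" holds only
    # for the second element
    (result,) = rounded_comparison(x, y, lsbs_to_remove=2, operation=lambda d: d < 0)
    assert numpy.array_equal(result, numpy.array([False, True]))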