conv3d_same.py
""" Conv3d w/ Same Padding
modified from:
https://github.com/rwightman/pytorch-image-models/blob/a2727c1bf78ba0d7b5727f5f95e37fb7f8866b1f/timm/models/layers/conv2d_same.py
https://github.com/rwightman/pytorch-image-models/blob/a2727c1bf78ba0d7b5727f5f95e37fb7f8866b1f/timm/models/layers/padding.py
"""
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Tuple, Optional, List

# Calculate symmetric padding for a convolution
def get_padding(kernel_size: int, stride: int = 1, dilation: int = 1, **_) -> int:
    padding = ((stride - 1) + dilation * (kernel_size - 1)) // 2
    return padding
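
# Example (illustrative): get_padding(3) == 1, the familiar padding=1 for a
# 3x3x3 conv; with dilation, get_padding(3, dilation=2) == 2.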

# Calculate asymmetric TensorFlow-like 'SAME' padding for a convolution
def get_same_padding(x: int, k: int, s: int, d: int):
    return max((math.ceil(x / s) - 1) * s + (k - 1) * d + 1 - x, 0)
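
# Example (illustrative): for x=10, k=3, s=2, d=1 the total pad is
# (ceil(10/2) - 1) * 2 + (3 - 1) + 1 - 10 = 1, so the conv output length is
# ceil(10/2) = 5, matching TF 'SAME' semantics.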

# Can SAME padding for given args be done statically?
def is_static_pad(kernel_size: int, stride: int = 1, dilation: int = 1, **_):
    return stride == 1 and (dilation * (kernel_size - 1)) % 2 == 0
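
# Example (illustrative): is_static_pad(3) is True (stride 1, odd kernel pads
# symmetrically), while is_static_pad(3, stride=2) and is_static_pad(2) are
# False, since those cases may need input-dependent, asymmetric padding.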

# Dynamically pad input x with 'SAME' padding for conv with specified args
def pad_same(x, k: List[int], s: List[int], d: List[int] = (1, 1, 1), value: float = 0):
    ih, iw, iz = x.size()[-3:]
    pad_h = get_same_padding(ih, k[0], s[0], d[0])
    pad_w = get_same_padding(iw, k[1], s[1], d[1])
    pad_z = get_same_padding(iz, k[2], s[2], d[2])
    if pad_h > 0 or pad_w > 0 or pad_z > 0:
        # F.pad lists padding from the last dimension backwards, so iz's pad
        # must come first: (iz front/back, iw front/back, ih front/back)
        x = F.pad(x, [pad_z // 2, pad_z - pad_z // 2, pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2], value=value)
    return x
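
# Example (illustrative): a (1, 1, 5, 5, 5) input with k=(3, 3, 3) and
# s=(2, 2, 2) gets 2 extra voxels per spatial dim (split 1/1), so the
# subsequent conv produces ceil(5/2) = 3 per dim.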

def get_padding_value(padding, kernel_size, **kwargs) -> Tuple[Tuple, bool]:
    dynamic = False
    if isinstance(padding, str):
        # for any string padding, the padding will be calculated for you, one of three ways
        padding = padding.lower()
        if padding == 'same':
            # TF compatible 'SAME' padding, has a performance and GPU memory allocation impact
            if is_static_pad(kernel_size, **kwargs):
                # static case, no extra overhead
                padding = get_padding(kernel_size, **kwargs)
            else:
                # dynamic 'SAME' padding, has runtime/GPU memory overhead
                padding = 0
                dynamic = True
        elif padding == 'valid':
            # 'VALID' padding, same as padding=0
            padding = 0
        else:
            # Default to PyTorch style 'same'-ish symmetric padding
            padding = get_padding(kernel_size, **kwargs)
    return padding, dynamic
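
# Example (illustrative): get_padding_value('same', 3) -> (1, False) (static),
# get_padding_value('same', 3, stride=2) -> (0, True) (dynamic), and
# get_padding_value('valid', 3) -> (0, False); non-string padding passes through.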

def conv3d_same(
        x, weight: torch.Tensor, bias: Optional[torch.Tensor] = None, stride: Tuple[int, int, int] = (1, 1, 1),
        padding: Tuple[int, int, int] = (0, 0, 0), dilation: Tuple[int, int, int] = (1, 1, 1), groups: int = 1):
    # `padding` is accepted for signature compatibility but unused; SAME padding
    # is applied dynamically by pad_same instead
    x = pad_same(x, weight.shape[-3:], stride, dilation)
    return F.conv3d(x, weight, bias, stride, (0, 0, 0), dilation, groups)

class Conv3dSame(nn.Conv3d):
    """ Tensorflow like 'SAME' convolution wrapper for 3d convolutions
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=True):
        super(Conv3dSame, self).__init__(
            in_channels, out_channels, kernel_size, stride, 0, dilation, groups, bias)

    def forward(self, x):
        return conv3d_same(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
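
# Example (illustrative): Conv3dSame(4, 8, 3, stride=2) applied to a
# (1, 4, 9, 9, 9) input yields shape (1, 8, 5, 5, 5), i.e. ceil(9/2) per dim.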

def create_conv3d_pad(in_chs, out_chs, kernel_size, **kwargs):
    padding = kwargs.pop('padding', '')
    kwargs.setdefault('bias', False)
    padding, is_dynamic = get_padding_value(padding, kernel_size, **kwargs)
    if is_dynamic:
        return Conv3dSame(in_chs, out_chs, kernel_size, **kwargs)
    else:
        return nn.Conv3d(in_chs, out_chs, kernel_size, padding=padding, **kwargs)
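
# Minimal usage sketch (illustrative, not part of the upstream module): with
# stride 2, 'same' padding cannot be resolved statically, so create_conv3d_pad
# falls back to the dynamic Conv3dSame wrapper.
if __name__ == '__main__':
    conv = create_conv3d_pad(4, 8, kernel_size=3, stride=2, padding='same')
    assert isinstance(conv, Conv3dSame)
    y = conv(torch.randn(2, 4, 9, 9, 9))
    # TF 'SAME' semantics: each spatial dim becomes ceil(9 / 2) = 5
    assert y.shape == (2, 8, 5, 5, 5)
    # stride 1 with an odd kernel resolves statically to a plain nn.Conv3d
    assert type(create_conv3d_pad(4, 8, kernel_size=3, padding='same')) is nn.Conv3d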