-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathSpatialSkew.cu
172 lines (140 loc) · 5.44 KB
/
SpatialSkew.cu
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
#include "THCUNN.h"
#include "THCDeviceTensor.cuh"
#include "THCDeviceTensorUtils.cuh"
#include "THCDeviceUtils.cuh"
#include "THCReduceApplyUtils.cuh"
#include "utils.h"
#include "common.h"
// Forward skew kernel: for every output point (y, x) of plane (n, c),
//   output[n][c][y][x] = input[n][c][y][x - y]   when 0 <= x - y < W_in.
// Output points whose source column falls outside the input row are not
// written at all, so `output` is expected to be zeroed before launch.
//
// Launch layout expected:
//   blockIdx.x * blockDim.x + threadIdx.x -> row-major flat index of a point
//                                            inside one output plane,
//   blockIdx.y                             -> plane (channel) index,
//   blockIdx.z                             -> batch index.
// No shared memory is used; threads past the end of the plane do nothing.
__global__ void SpatialSkew_updateOutput(
    THCDeviceTensor<float, 4> input,
    THCDeviceTensor<float, 4> output) {
  const int n = blockIdx.z;
  const int c = blockIdx.y;
  const int outWidth = output.getSize(3);
  const int flat = blockIdx.x * blockDim.x + threadIdx.x;

  if (flat < output.getSize(2) * outWidth) {
    const int row = flat / outWidth;
    const int col = flat % outWidth;
    // Row `row` is shifted right by `row` columns.
    const int srcCol = col - row;
    if (srcCol >= 0 && srcCol < input.getSize(3)) {
      output[n][c][row][col] = input[n][c][row][srcCol];
    }
  }
}
/*
 * Lua binding for the forward pass of SpatialSkew.
 *
 * Lua stack:
 *   1 = module table; its `output` field must be a torch.CudaTensor,
 *   2 = input torch.CudaTensor, either (C x H x W) or (N x C x H x W).
 *
 * Resizes `output` to (C x H x (W + H - 1)) or (N x C x H x (W + H - 1)),
 * zeroes it, and launches SpatialSkew_updateOutput so that row y of every
 * input plane lands in columns [y, y + W) of the same output row.
 *
 * Returns 1 Lua value (presumably the `output` field left on the stack by
 * luaT_getfieldcheckudata -- verify against the luaT version in use).
 */
static int extracunn_SpatialSkew_updateOutput(lua_State *L)
{
  THCState *state = getCutorchState(L);
  THCudaTensor *input = (THCudaTensor*)luaT_checkudata(L, 2, "torch.CudaTensor");
  THCudaTensor *output = (THCudaTensor*)luaT_getfieldcheckudata(L, 1, "output", "torch.CudaTensor");

  // Dimension indices for the non-batched (C x H x W) layout; shifted by one
  // below when a leading batch dimension is present.
  int planeDim = 0;
  int dimh = 1;
  int dimw = 2;
  int numBatch = 1;
  int numInputDims = THCudaTensor_nDimension(state, input);
  THArgCheck(numInputDims == 3 || numInputDims == 4, 2,
             "input must be 3 or 4-dimensional");
  if (numInputDims == 4) {
    numBatch = THCudaTensor_size(state, input, 0);
    planeDim++;
    dimh++;
    dimw++;
  }

  int numPlanes = THCudaTensor_size(state, input, planeDim);
  int inputH = THCudaTensor_size(state, input, dimh);
  int inputW = THCudaTensor_size(state, input, dimw);
  // Row y is shifted right by y, so the last row needs inputH - 1 extra columns.
  int outputH = inputH;
  int outputW = inputW + inputH - 1;

  THCDeviceTensor<float, 4> devInput;
  THCDeviceTensor<float, 4> devOutput;
  if (numInputDims == 3) {
    THCudaTensor_resize3d(state, output, numPlanes, outputH, outputW);
    // The kernel never writes points outside the skewed band; they must be 0.
    THCudaTensor_zero(state, output);
    // Treat a 3-D tensor as a batch of one so a single kernel handles both.
    devInput = toDeviceTensor<float, 3>(state, input).upcastOuter<4>();
    devOutput = toDeviceTensor<float, 3>(state, output).upcastOuter<4>();
  } else {
    THCudaTensor_resize4d(state, output, numBatch, numPlanes, outputH, outputW);
    THCudaTensor_zero(state, output);
    devInput = toDeviceTensor<float, 4>(state, input);
    devOutput = toDeviceTensor<float, 4>(state, output);
  }

  // One thread per output point of a plane; grid y/z enumerate planes/batches.
  int outputPlaneSize = devOutput.getSize(2) * devOutput.getSize(3);
  if (outputPlaneSize > 0) {
    dim3 gridSize(THCCeilDiv(outputPlaneSize, 256),
                  devOutput.getSize(1),
                  devOutput.getSize(0));
    dim3 blockSize(outputPlaneSize > 256 ? 256 : outputPlaneSize);
    SpatialSkew_updateOutput<<<gridSize, blockSize, 0, THCState_getCurrentStream(state)>>>(
        devInput, devOutput);
    // Kernel launches do not report errors directly; surface bad launch
    // configurations (e.g. too many planes/batches for the grid) here.
    THCudaCheck(cudaGetLastError());
  }
  return 1;
}
// Backward skew kernel: gathers the gradient of every input point from the
// single output point it was copied to in the forward pass,
//   gradInput[n][c][y][x - y] = gradOutput[n][c][y][x]   when 0 <= x - y < W_in.
// Output points outside the skewed band contribute nothing.
//
// Launch layout expected (one thread per gradOutput point of a plane):
//   blockIdx.x * blockDim.x + threadIdx.x -> row-major flat index into one
//                                            gradOutput plane,
//   blockIdx.y                             -> plane (channel) index,
//   blockIdx.z                             -> batch index.
// No shared memory is used; threads past the end of the plane do nothing.
__global__ void SpatialSkew_updateGradInput(
    THCDeviceTensor<float, 4> gradInput,
    THCDeviceTensor<float, 4> gradOutput) {
  const int n = blockIdx.z;
  const int c = blockIdx.y;
  const int goWidth = gradOutput.getSize(3);
  const int flat = blockIdx.x * blockDim.x + threadIdx.x;

  if (flat < gradOutput.getSize(2) * goWidth) {
    const int row = flat / goWidth;
    const int col = flat % goWidth;
    const int dstCol = col - row;
    if (dstCol >= 0 && dstCol < gradInput.getSize(3)) {
      // Within one plane, (row, col) -> (row, col - row) is one-to-one, so
      // each gradInput element has exactly one writer: a plain store is
      // race-free and no atomicAdd is required.
      gradInput[n][c][row][dstCol] = gradOutput[n][c][row][col];
    }
  }
}
/*
 * Lua binding for the backward pass of SpatialSkew.
 *
 * Lua stack:
 *   1 = module table; its `gradInput` field must be a torch.CudaTensor,
 *   2 = input torch.CudaTensor, (C x H x W) or (N x C x H x W),
 *   3 = gradOutput torch.CudaTensor, shaped like the forward output:
 *       (C x H x (W + H - 1)) or (N x C x H x (W + H - 1)).
 *
 * Resizes `gradInput` like `input`, zeroes it, and launches
 * SpatialSkew_updateGradInput to un-skew gradOutput into it.
 *
 * gradOutput's shape is validated up front. The launch grid and the
 * row/plane/batch indices are all taken from gradOutput, while the writes
 * target gradInput. A mismatched gradOutput would otherwise write out of
 * bounds.
 *
 * Returns 1 Lua value (presumably the `gradInput` field left on the stack by
 * luaT_getfieldcheckudata -- verify against the luaT version in use).
 */
static int extracunn_SpatialSkew_updateGradInput(lua_State *L)
{
  THCState *state = getCutorchState(L);
  // Inputs
  THCudaTensor *input = (THCudaTensor *)luaT_checkudata(L, 2, "torch.CudaTensor");
  THCudaTensor *gradOutput = (THCudaTensor *)luaT_checkudata(L, 3, "torch.CudaTensor");
  THCudaTensor *gradInput = (THCudaTensor *)luaT_getfieldcheckudata(L, 1, "gradInput", "torch.CudaTensor");

  // Dimension indices for the non-batched (C x H x W) layout; shifted by one
  // below when a leading batch dimension is present.
  int planeDim = 0;
  int dimh = 1;
  int dimw = 2;
  int numInputDims = THCudaTensor_nDimension(state, input);
  THArgCheck(numInputDims == 3 || numInputDims == 4, 2,
             "input must be 3 or 4-dimensional");
  THArgCheck(THCudaTensor_nDimension(state, gradOutput) == numInputDims, 3,
             "gradOutput must have the same number of dimensions as input");
  if (numInputDims == 4) {
    THArgCheck(THCudaTensor_size(state, gradOutput, 0) ==
               THCudaTensor_size(state, input, 0), 3,
               "gradOutput batch size does not match input");
    planeDim++;
    dimh++;
    dimw++;
  }

  long numPlanes = THCudaTensor_size(state, input, planeDim);
  long inputH = THCudaTensor_size(state, input, dimh);
  long inputW = THCudaTensor_size(state, input, dimw);
  THArgCheck(THCudaTensor_size(state, gradOutput, planeDim) == numPlanes, 3,
             "gradOutput plane count does not match input");
  THArgCheck(THCudaTensor_size(state, gradOutput, dimh) == inputH, 3,
             "gradOutput height does not match input");
  // The forward pass widens each row by inputH - 1 columns.
  THArgCheck(THCudaTensor_size(state, gradOutput, dimw) == inputW + inputH - 1, 3,
             "gradOutput width must equal input width + input height - 1");

  THCudaTensor_resizeAs(state, gradInput, input);
  THCudaTensor_zero(state, gradInput);

  THCDeviceTensor<float, 4> devGradInput;
  THCDeviceTensor<float, 4> devGradOutput;
  if (numInputDims == 3) {
    // Treat a 3-D tensor as a batch of one so a single kernel handles both.
    devGradInput = toDeviceTensor<float, 3>(state, gradInput).upcastOuter<4>();
    devGradOutput = toDeviceTensor<float, 3>(state, gradOutput).upcastOuter<4>();
  } else {
    devGradInput = toDeviceTensor<float, 4>(state, gradInput);
    devGradOutput = toDeviceTensor<float, 4>(state, gradOutput);
  }

  // One thread per gradOutput point of a plane; grid y/z enumerate planes/batches.
  int outputPlaneSize = devGradOutput.getSize(2) * devGradOutput.getSize(3);
  if (outputPlaneSize > 0) {
    dim3 gridSize(THCCeilDiv(outputPlaneSize, 256),
                  devGradOutput.getSize(1),
                  devGradOutput.getSize(0));
    dim3 blockSize(outputPlaneSize > 256 ? 256 : outputPlaneSize);
    SpatialSkew_updateGradInput<<<gridSize, blockSize, 0, THCState_getCurrentStream(state)>>>(
        devGradInput, devGradOutput);
    // Kernel launches do not report errors directly; surface them here.
    THCudaCheck(cudaGetLastError());
  }
  return 1;
}
// Lua-visible name -> C implementation table for the SpatialSkew bindings.
// The {NULL, NULL} entry terminates the list, as luaL_Reg arrays require.
static const struct luaL_Reg extracunn_SpatialSkew__ [] = {
{"SpatialSkew_updateOutput", extracunn_SpatialSkew_updateOutput},
{"SpatialSkew_updateGradInput", extracunn_SpatialSkew_updateGradInput},
{NULL, NULL}
};
// Registers the SpatialSkew bindings with Lua.
// Pushes the torch.CudaTensor metatable, registers the functions from
// extracunn_SpatialSkew__ under the name "nn" on it, then pops the metatable
// to restore the Lua stack.
// NOTE(review): presumably this makes them callable from Lua as
// input.nn.SpatialSkew_updateOutput(self, input) -- which would match the
// stack layout (1 = self, 2 = input) the bindings read. Confirm against the
// Lua-side module.
void extracunn_SpatialSkew_init(lua_State *L)
{
luaT_pushmetatable(L, "torch.CudaTensor");
luaT_registeratname(L, extracunn_SpatialSkew__, "nn");
lua_pop(L,1);
}