-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathHuber.cu
41 lines (31 loc) · 916 Bytes
/
Huber.cu
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
#include "THCApply.cuh"
#include "utils.h"
// Element-wise functor for THC_pointwiseApply1: clamps the pointed-to float
// in place into the closed range [-threshold_, threshold_].
//
// - Values already inside the range are left untouched (no store is issued).
// - NaN inputs are also left untouched, because every comparison with NaN is
//   false.
// - NOTE(review): assumes threshold_ >= 0. This is not checked. With a negative
//   threshold the two branches overlap and the mapping no longer clamps.
//
// operator() is marked __host__ __device__ so the same clamp can be evaluated
// on the CPU as well (e.g. as a reference implementation in tests).
struct Huber {
  const float threshold_;

  Huber(float threshold): threshold_(threshold) {}

  __host__ __device__ __forceinline__ void operator()(float* x) const {
    const float v = *x;
    if (v > threshold_) {
      *x = threshold_;
    } else if (v < -threshold_) {
      *x = -threshold_;
    }
    // Otherwise v is already within [-threshold_, threshold_]: nothing to write.
  }
};
// Lua binding registered as "Huber" on torch.CudaTensor.
//
// Clamps every element of the given CUDA tensor, in place, into the closed
// range [-threshold, threshold]. It does this by running the Huber functor
// through THC_pointwiseApply1.
//
// Lua arguments:
//   1: torch.CudaTensor -- modified in place
//   2: number           -- clipping threshold
//                          (NOTE(review): assumed >= 0, not validated here)
// Lua return value:
//   the input tensor (stack index 1), so the call can be chained or assigned.
// Raises a Lua error:
//   if THC_pointwiseApply1 reports that it could not process the tensor.
static int extracunn_Huber(lua_State *L)
{
  THCState *state = getCutorchState(L);
  THCudaTensor *input = (THCudaTensor*)luaT_checkudata(L, 1, "torch.CudaTensor");
  // The functor stores a float; make the double -> float narrowing explicit.
  float threshold = (float)luaL_checknumber(L, 2);

  // THC_pointwiseApply1 signals failure (in cutorch: a tensor with more
  // dimensions than it supports) through its return value; don't drop it.
  if (!THC_pointwiseApply1(state, input, Huber(threshold))) {
    return luaL_error(L, "Huber: THC_pointwiseApply1 could not process the tensor (too many dimensions?)");
  }
  THCudaCheck(cudaGetLastError());

  // Neither luaT_checkudata nor luaL_checknumber pushes anything onto the
  // stack. Without this push, "return 1" would hand back whatever argument
  // sits on top (the threshold) instead of the tensor.
  lua_pushvalue(L, 1);
  return 1;
}
// Function table that extracunn_Huber_init registers into the
// torch.CudaTensor metatable: maps the Lua name "Huber" to extracunn_Huber.
// The trailing {NULL, NULL} sentinel marks the end of the list for
// luaL_register.
static const luaL_Reg extracunn_Huber__[] =
{
  { "Huber", extracunn_Huber },
  { NULL,    NULL            }
};
// Registers the functions listed in extracunn_Huber__ into the
// torch.CudaTensor metatable, then pops the metatable so the Lua stack is
// left balanced.
//
// luaT_pushmetatable pushes nothing and returns 0 when the class is unknown
// (e.g. cutorch has not been loaded yet). In that case registering would
// write into whatever table happens to be on top of the stack, and the pop
// would remove a value this function never pushed. So fail loudly instead.
void extracunn_Huber_init(lua_State *L)
{
  if (!luaT_pushmetatable(L, "torch.CudaTensor")) {
    luaL_error(L, "extracunn: torch.CudaTensor metatable not found (is cutorch loaded?)");
    return;  // not reached: luaL_error does not return
  }
  // A NULL library name makes luaL_register fill the table on top of the stack.
  luaL_register(L, NULL, extracunn_Huber__);
  lua_pop(L, 1);  // drop the metatable pushed above
}