forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathgroup_norm.h
42 lines (35 loc) · 905 Bytes
/
group_norm.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
#pragma once
#include <ATen/native/DispatchStub.h>
#include <cstdint>
namespace at {

// Forward declaration only: every use of Tensor in this header is by
// reference, so the full definition is not needed here.
class Tensor;

namespace native {

// Function-pointer type for a GroupNorm forward kernel.
//
// Read-only inputs:  X, gamma, beta (const Tensor&).
// Scalar arguments:  four int64_t values named N, C, HxW and group, followed
//                    by a double named eps (presumably the numerical
//                    epsilon added to the variance — TODO confirm).
// Mutable arguments: Y, mean and rstd are taken by non-const reference,
//                    presumably so the kernel can write the normalized output
//                    and the saved statistics into them.
//                    NOTE(review): their expected shapes are not fixed by
//                    this header; confirm against the registered kernels.
using forward_fn = void (*)(
    const Tensor& X,
    const Tensor& gamma,
    const Tensor& beta,
    int64_t N,
    int64_t C,
    int64_t HxW,
    int64_t group,
    double eps,
    Tensor& Y,
    Tensor& mean,
    Tensor& rstd);

// Function-pointer type for a GroupNorm backward kernel.
//
// Read-only inputs:  dY, X, mean, rstd, gamma (const Tensor&). mean/rstd are
//                    presumably the statistics produced by a forward_fn call.
// Scalar arguments:  the same four int64_t values as forward_fn (N, C, HxW,
//                    group); there is no eps parameter here.
// Mutable arguments: dX, dgamma and dbeta are taken by non-const reference,
//                    presumably the gradient outputs.
using backward_fn = void (*)(
    const Tensor& dY,
    const Tensor& X,
    const Tensor& mean,
    const Tensor& rstd,
    const Tensor& gamma,
    int64_t N,
    int64_t C,
    int64_t HxW,
    int64_t group,
    Tensor& dX,
    Tensor& dgamma,
    Tensor& dbeta);

// Dispatch stubs for the two kernel signatures above, declared through the
// DECLARE_DISPATCH macro from ATen/native/DispatchStub.h. This header only
// declares them; no implementation is provided here.
// NOTE(review): invoked without a trailing semicolon, presumably because the
// macro expansion supplies its own — keep in sync with DispatchStub.h.
DECLARE_DISPATCH(forward_fn, GroupNormKernel)
DECLARE_DISPATCH(backward_fn, GroupNormBackwardKernel)

} // namespace native
} // namespace at