forked from pierotofy/OpenSplat
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathspherical_harmonics.cpp
71 lines (60 loc) · 1.78 KB
/
spherical_harmonics.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
#include "spherical_harmonics.hpp"
#include "vendor/gsplat/bindings.h"
// Returns the number of spherical harmonic bases for the given degree:
// 1, 4, 9 and 16 for degrees 0 through 3, and 25 for every other value.
int numShBases(int degree){
    // Degrees 0..3 follow the (degree + 1)^2 pattern
    if (degree >= 0 && degree <= 3){
        const int bands = degree + 1;
        return bands * bands;
    }

    // Any other input (including negative values) falls back to 25
    return 25;
}
// Inverse lookup of the basis count: 1, 4, 9 and 16 bases map to
// degrees 0 through 3; any other count yields 4.
int degFromSh(int numBases){
    // Search the degrees whose basis count is (deg + 1)^2
    for (int deg = 0; deg <= 3; ++deg){
        const int bases = (deg + 1) * (deg + 1);
        if (bases == numBases){
            return deg;
        }
    }

    // Unrecognized counts fall back to degree 4
    return 4;
}
// Maps RGB values in [0,1] to the 0th spherical harmonic coefficient
// by removing the 0.5 offset and dividing by the constant C0.
torch::Tensor rgb2sh(const torch::Tensor &rgb){
    const double C0 = 0.28209479177387814;

    // Shift first, then scale: (rgb - 0.5) / C0
    torch::Tensor centered = rgb - 0.5;
    return centered / C0;
}
// Forward pass of the spherical harmonics autograd function.
//
// ctx:          autograd context; receives the values backward() reads later
// degreesToUse: forwarded unchanged to compute_sh_forward_tensor
// viewDirs:     saved for backward and forwarded to compute_sh_forward_tensor
// coeffs:       its first dimension gives the point count and its
//               second-to-last dimension gives the basis count
//
// Returns whatever compute_sh_forward_tensor produces.
torch::Tensor SphericalHarmonics::forward(AutogradContext *ctx,
            int degreesToUse,
            torch::Tensor viewDirs,
            torch::Tensor coeffs){
    // Sizes are read from coeffs before the context is touched
    long long pointCount = coeffs.size(0);
    int shDegree = degFromSh(coeffs.size(-2));

    // Stash everything backward() needs
    ctx->save_for_backward({ viewDirs });
    ctx->saved_data["degree"] = shDegree;
    ctx->saved_data["degreesToUse"] = degreesToUse;

    return compute_sh_forward_tensor(pointCount, shDegree, degreesToUse, viewDirs, coeffs);
}
// Backward pass of the spherical harmonics autograd function.
//
// Reads the "degreesToUse" and "degree" entries and the saved view
// directions from ctx, then hands them with the incoming gradient
// (grad_outputs[0]) to compute_sh_backward_tensor.
//
// Returns three entries: two undefined tensors followed by the result of
// compute_sh_backward_tensor (presumably one slot per forward argument:
// degreesToUse, viewDirs, coeffs).
tensor_list SphericalHarmonics::backward(AutogradContext *ctx, tensor_list grad_outputs){
    torch::Tensor colorGrads = grad_outputs[0];

    // Recover the values stored by forward()
    int activeDegrees = ctx->saved_data["degreesToUse"].toInt();
    int shDegree = ctx->saved_data["degree"].toInt();
    variable_list savedTensors = ctx->get_saved_variables();
    torch::Tensor dirs = savedTensors[0];

    // Point count comes from the first dimension of the incoming gradient
    long long pointCount = colorGrads.size(0);

    torch::Tensor coeffGrads = compute_sh_backward_tensor(pointCount, shDegree, activeDegrees, dirs, colorGrads);

    // Default-constructed (undefined) tensors fill the first two slots
    return { torch::Tensor(), torch::Tensor(), coeffGrads };
}