forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path Onehot.cpp
69 lines (60 loc) · 2.59 KB
/
Onehot.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/arange.h>
#include <ATen/ops/empty.h>
#include <ATen/ops/eq.h>
#include <ATen/ops/one_hot_native.h>
#include <ATen/ops/zeros.h>
#endif
namespace at::native {

// One-hot encodes an int64 index tensor.
//
// Returns a kLong tensor of shape `self.sizes() + [num_classes]` whose entry
// [..., c] is 1 where self[...] == c and 0 elsewhere.
//
// num_classes == -1 means "infer": the width becomes max(self) + 1, which
// reads a value back from the tensor via `.item()`.
//
// Errors (TORCH_CHECK), all paths:
//   - self is not of dtype kLong;
//   - num_classes < -1.
// Additional errors on the eager (non-meta, non-Python-key) path:
//   - self is empty and num_classes <= 0 (no values to infer a width from);
//   - on devices other than CUDA / MPS / PrivateUse1 / XLA: any value < 0, or
//     any value >= an explicitly given num_classes;
//   - on CUDA / MPS / PrivateUse1 / XLA: an explicit num_classes < 1 (value
//     range errors are left to the device-side asserts raised by scatter_).
// The meta / Python-key path performs no value validation of its own.
Tensor one_hot(const Tensor &self, int64_t num_classes) {
  TORCH_CHECK(self.dtype() == kLong, "one_hot is only applicable to index tensor of type LongTensor.");
  // Reject invalid widths before dispatching, so every path reports this
  // clear error instead of a misleading one. Without this check the eager
  // path reports "Class values must be smaller than num_classes." and the
  // meta path fails with an arange bounds error.
  TORCH_CHECK(num_classes >= -1,
              "one_hot: num_classes must be -1 (infer from input) or non-negative, but got ",
              num_classes);

  // using meta bit test to catch Fake Tensor as well until __torch_function__
  const bool is_meta_or_python =
      self.key_set().has_all(DispatchKeySet(BackendComponent::MetaBit)) ||
      self.key_set().has_all(DispatchKeySet(DispatchKey::Python));
  if (is_meta_or_python) {
    // Functional version (compare against arange, no in-place scatter_) that
    // torch.compiles better and works with dynamic shapes.
    if (num_classes == -1) {
      num_classes = self.max().item().toLong() + 1;
    }
    at::Tensor index = at::arange(num_classes, self.options());
    return at::eq(self.unsqueeze(-1), index).to(kLong);
  }

  auto shape = self.sizes().vec();

  // An empty tensor can still be converted to a one-hot representation, but
  // the number of classes cannot be inferred from it, so it must be given.
  if (self.numel() == 0) {
    TORCH_CHECK(num_classes > 0, "Can not infer total number of classes from empty tensor.");
    shape.push_back(num_classes);
    return at::empty(shape, self.options());
  }

  // Non-empty tensor.
  // On these backends, skip the host-side value checks (each `.item()` forces
  // a sync). Rely on the device asserts thrown by scatter_ instead.
  const auto device_type = self.device().type();
  const bool rely_on_device_asserts =
      device_type == at::kCUDA || device_type == at::kMPS ||
      device_type == at::kPrivateUse1 || device_type == at::kXLA;

  if (!rely_on_device_asserts) {
    TORCH_CHECK(self.min().item().toLong() >= 0, "Class values must be non-negative.");
  }

  if (num_classes == -1) {
    num_classes = self.max().item().toLong() + 1;
  } else if (!rely_on_device_asserts) {
    TORCH_CHECK(num_classes > self.max().item().toLong(), "Class values must be smaller than num_classes.");
  } else {
    // Values are checked on-device by scatter_; only validate the width here.
    TORCH_CHECK(num_classes >= 1, "num_classes should be positive");
  }

  shape.push_back(num_classes);
  Tensor ret = at::zeros(shape, self.options());
  // Write a 1 at position self[...] along the new trailing dimension.
  ret.scatter_(-1, self.unsqueeze(-1), 1);
  return ret;
}

} // namespace at::native