-
Notifications
You must be signed in to change notification settings - Fork 712
/
Copy pathtnn_glint_arcface.cpp
111 lines (76 loc) · 2.5 KB
/
tnn_glint_arcface.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
//
// Created by DefTruth on 2021/11/13.
//
#include "tnn_glint_arcface.h"
using tnncv::TNNGlintArcFace;
/**
 * Constructs the Glint ArcFace face-embedding handler.
 *
 * This constructor performs no work of its own: every argument is forwarded
 * unchanged to BasicTNNHandler, which owns model loading and configuration.
 *
 * @param _proto_path  TNN proto file path, forwarded to BasicTNNHandler.
 * @param _model_path  TNN model file path, forwarded to BasicTNNHandler.
 * @param _num_threads thread setting, forwarded to BasicTNNHandler
 *                     (presumably the inference thread count — confirm in
 *                     the base class).
 */
TNNGlintArcFace::TNNGlintArcFace(const std::string &_proto_path,
                                 const std::string &_model_path,
                                 unsigned int _num_threads)
    : BasicTNNHandler(_proto_path,
                      _model_path,
                      _num_threads)
{
  // Intentionally empty: all setup happens in BasicTNNHandler.
}
/**
 * Wraps an already resized, RGB-ordered image in a tnn::Mat and stores it in
 * the `input_mat` member.
 *
 * The tnn::Mat is built over `mat_rs.data` with type N8UC3 and the handler's
 * `input_shape`. NOTE(review): this constructor presumably wraps the external
 * buffer without copying it — confirm against tnn::Mat — so `mat_rs` must
 * stay alive until `input_mat` has been consumed by SetInputMat.
 *
 * On failure (null data pointer) nothing is returned to the caller; a message
 * is printed only when LITETNN_DEBUG is defined.
 *
 * @param mat_rs pre-processed image whose size must match `input_shape`.
 */
void TNNGlintArcFace::transform(const cv::Mat &mat_rs)
{
  void *pixels = static_cast<void *>(mat_rs.data);
  input_mat = std::make_shared<tnn::Mat>(input_device_type, tnn::N8UC3,
                                         input_shape, pixels);

  const bool wrapped_ok = (input_mat->GetData() != nullptr);
  if (wrapped_ok) return;

#ifdef LITETNN_DEBUG
  std::cout << "input_mat == nullptr! transform failed\n";
#endif
}
/**
 * Computes an L2-normalised ArcFace embedding for a face image.
 *
 * Pipeline:
 *   1. resize to (input_width, input_height) and convert BGR -> RGB;
 *   2. wrap the pixels via transform() and upload with SetInputMat using
 *      scale_vals / bias_vals;
 *   3. run Forward();
 *   4. read the "embedding" output, copy it and L2-normalise it.
 *
 * On success, face_content.embedding holds the normalised vector,
 * face_content.dim its length, and face_content.flag is set to true.
 * On any failure the function returns early without modifying face_content.
 * Failure cases are an empty input, an invalid wrapped input, a TNN error,
 * or a missing or ill-shaped output.
 *
 * @param mat          input image; converted with COLOR_BGR2RGB, so a
 *                     3-channel BGR image is expected.
 * @param face_content receives the embedding on success.
 */
void TNNGlintArcFace::detect(const cv::Mat &mat, types::FaceContent &face_content)
{
  if (mat.empty()) return;

  // 1. make input tensor
  cv::Mat mat_rs;
  cv::resize(mat, mat_rs, cv::Size(input_width, input_height));
  cv::cvtColor(mat_rs, mat_rs, cv::COLOR_BGR2RGB);
  this->transform(mat_rs);
  // transform() only reports failure via a debug log; do not hand an
  // invalid Mat to TNN.
  if (!input_mat || !input_mat->GetData()) return;

  // 2. set input_mat (mat_rs is still in scope here, so the wrapped pixel
  //    buffer is valid during upload)
  tnn::MatConvertParam input_cvt_param;
  input_cvt_param.scale = scale_vals;
  input_cvt_param.bias = bias_vals;
  tnn::Status status;
  status = instance->SetInputMat(input_mat, input_cvt_param);
  if (status != tnn::TNN_OK)
  {
#ifdef LITETNN_DEBUG
    std::cout << "instance->SetInputMat failed!:"
              << status.description().c_str() << "\n";
#endif
    return;
  }

  // 3. forward
  status = instance->Forward();
  if (status != tnn::TNN_OK)
  {
#ifdef LITETNN_DEBUG
    std::cout << "instance->Forward failed!:"
              << status.description().c_str() << "\n";
#endif
    return;
  }

  // 4. fetch output mat
  std::shared_ptr<tnn::Mat> embedding_mat;
  tnn::MatConvertParam embed_cvt_param; // default
  status = instance->GetOutputMat(embedding_mat, embed_cvt_param, "embedding", output_device_type);
  if (status != tnn::TNN_OK)
  {
#ifdef LITETNN_DEBUG
    std::cout << "instance->GetOutputMat failed!:"
              << status.description().c_str() << "\n";
#endif
    return;
  }
  // A TNN_OK status does not by itself guarantee a populated Mat; guard
  // before dereferencing instead of invoking UB.
  if (!embedding_mat || !embedding_mat->GetData())
  {
#ifdef LITETNN_DEBUG
    std::cout << "embedding output is empty!\n";
#endif
    return;
  }

  const auto embedding_dims = embedding_mat->GetDims(); // expected (1,512)
  // Fail gracefully instead of letting .at(1) throw on a rank<2 output.
  if (embedding_dims.size() < 2 || embedding_dims.at(1) <= 0)
  {
#ifdef LITETNN_DEBUG
    std::cout << "unexpected embedding dims!\n";
#endif
    return;
  }
  const unsigned int hidden_dim = static_cast<unsigned int>(embedding_dims.at(1));
  const float *embedding_values = static_cast<const float *>(embedding_mat->GetData());
  std::vector<float> embedding_norm(embedding_values, embedding_values + hidden_dim);
  cv::normalize(embedding_norm, embedding_norm); // l2 normalize (NORM_L2 default)
  face_content.embedding.assign(embedding_norm.begin(), embedding_norm.end());
  face_content.dim = hidden_dim;
  face_content.flag = true;
}