-
Notifications
You must be signed in to change notification settings - Fork 35
/
Copy pathFeatureExtractorTest.cpp
112 lines (93 loc) · 4.06 KB
/
FeatureExtractorTest.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
#include "pch.h"
#include "CppUnitTest.h"
#include "MachineLearning/Imaging/Annotators/DepthEstimator.h"
#include "MachineLearning/Imaging/Annotators/EdgeDetector.h"
#include "MachineLearning/Imaging/Annotators/PoseEstimator.h"
#include "Storage/FileIO.h"
using namespace Axodox::Graphics;
using namespace Axodox::Storage;
using namespace Axodox::MachineLearning::Sessions;
using namespace Axodox::MachineLearning::Imaging::Annotators;
using namespace Microsoft::VisualStudio::CppUnitTestFramework;
using namespace std;
namespace Axodox::MachineLearning::Test
{
  // Smoke tests for the image annotators (depth, edge, pose).
  // Each test loads one ONNX model with the DirectML executor (OnnxExecutorType::Dml), runs it on an input
  // image and writes the resulting image next to the test library (lib_folder()) for manual inspection.
  // The tests do not assert on the produced image content.
  TEST_CLASS(FeatureExtractorTest)
  {
  private:
    // Input image shared by the depth and edge detection tests; decoded once in FeatureExtractorInitialize.
    inline static TextureData _imageTexture;

    // Runs the EdgeDetector model located at lib_folder() / modelRelativePath on the shared input image
    // and writes the first output texture as an image file to lib_folder() / outputFileName.
    // Used by both the Canny and the HED tests, which differ only in model and output name.
    static void RunEdgeDetection(const char* modelRelativePath, const char* outputFileName)
    {
      //Prepare input
      auto imageTensor = Tensor::FromTextureData(_imageTexture, ColorNormalization::LinearZeroToOne);

      //Load model
      auto modelPath = lib_folder() / modelRelativePath;
      EdgeDetector edgeDetector{ OnnxSessionParameters::Create(modelPath, OnnxExecutorType::Dml) };

      //Run edge detection
      auto result = edgeDetector.DetectEdges(imageTensor);

      //Convert output to image
      auto edgeTexture = result.ToTextureData(ColorNormalization::LinearZeroToOne);
      auto edgeData = edgeTexture[0].ToBuffer();
      auto outputPath = lib_folder() / outputFileName;
      write_file(outputPath, edgeData);
    }

  public:
    // Loads and decodes the shared input image once for the whole test class.
    TEST_CLASS_INITIALIZE(FeatureExtractorInitialize)
    {
      //Load input data (forward slashes, consistent with the model paths; std::filesystem accepts both on Windows)
      auto imagePath = lib_folder() / "../../../inputs/bedroom.png";
      auto imageData = read_file(imagePath);
      _imageTexture = TextureData::FromBuffer(imageData);
    }

    // Estimates depth for the shared image and writes the normalized depth map to depth.png.
    TEST_METHOD(TestDepthEstimation)
    {
      //Prepare input
      // NOTE(review): resized to 512x512, presumably the resolution the depth model expects - confirm.
      auto imageTexture = _imageTexture.Resize(512, 512);
      auto imageTensor = Tensor::FromTextureData(imageTexture, ColorNormalization::LinearZeroToOne);

      //Load model
      auto modelPath = lib_folder() / "../../../models/annotators/depth.onnx";
      DepthEstimator depthEstimator{ OnnxSessionParameters::Create(modelPath, OnnxExecutorType::Dml) };

      //Run depth estimation
      auto result = depthEstimator.EstimateDepth(imageTensor);

      //Convert output to image (NormalizeDepthTensor modifies result in place before conversion)
      DepthEstimator::NormalizeDepthTensor(result);
      auto depthTexture = result.ToTextureData(ColorNormalization::LinearZeroToOne);
      auto depthData = depthTexture[0].ToBuffer();
      auto outputPath = lib_folder() / "depth.png";
      write_file(outputPath, depthData);
    }

    // Runs the Canny edge model on the shared image and writes canny.png.
    TEST_METHOD(TestCannyEdgeDetection)
    {
      RunEdgeDetection("../../../models/annotators/canny.onnx", "canny.png");
    }

    // Runs the HED edge model on the shared image and writes hed.png.
    TEST_METHOD(TestHedEdgeDetection)
    {
      RunEdgeDetection("../../../models/annotators/hed.onnx", "hed.png");
    }

    // Runs the OpenPose model on a separate input image (football.jpg) and writes pose.png.
    TEST_METHOD(TestPoseDetection)
    {
      //Prepare input
      auto imagePath = lib_folder() / "../../../inputs/football.jpg";
      auto imageData = read_file(imagePath);
      auto imageTexture = TextureData::FromBuffer(imageData);

      //Load model
      auto modelPath = lib_folder() / "../../../models/annotators/openpose.onnx";
      PoseEstimator poseEstimator{ OnnxSessionParameters::Create(modelPath, OnnxExecutorType::Dml) };

      //Run pose estimation
      auto poseTexture = poseEstimator.ExtractFeatures(imageTexture);

      //Convert output to image
      auto poseData = poseTexture.ToBuffer();
      auto outputPath = lib_folder() / "pose.png";
      write_file(outputPath, poseData);
    }
  };
}