diff --git a/CMakeLists.txt b/CMakeLists.txt index 94f3409..3b5e372 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -72,6 +72,7 @@ set(PLUGIN_SOURCES src/face-tracker-monitor.cpp src/face-detector-base.cpp src/face-detector-dlib-hog.cpp + src/face-detector-dlib-cnn.cpp src/face-tracker-base.cpp src/face-tracker-dlib.cpp src/texture-object.cpp diff --git a/data/locale/en-US.ini b/data/locale/en-US.ini index e69de29..0148cfa 100644 --- a/data/locale/en-US.ini +++ b/data/locale/en-US.ini @@ -0,0 +1,2 @@ +Detector.dlib.hog="HOG, dlib" +Detector.dlib.cnn="CNN, dlib" diff --git a/src/face-detector-dlib-cnn.cpp b/src/face-detector-dlib-cnn.cpp new file mode 100644 index 0000000..d7cd9e6 --- /dev/null +++ b/src/face-detector-dlib-cnn.cpp @@ -0,0 +1,137 @@ +#include +#include +#include +#include +#include "plugin-macros.generated.h" +#include "face-detector-dlib-cnn.h" +#include "texture-object.h" + +#include +#include +#include +#include + +#define MAX_ERROR 2 + +using namespace dlib; +template using con5d = con; +template using con5 = con; +template using downsampler = relu>>>>>>>>; +template using rcon5 = relu>>; +using net_type = loss_mmod>>>>>>>; +typedef dlib::matrix image_t; + +struct private_s +{ + std::shared_ptr tex; + std::vector rects; + std::string model_filename; + net_type net; + bool net_loaded = false; + bool has_error = false; + int crop_l = 0, crop_r = 0, crop_t = 0, crop_b = 0; + int n_error = 0; +}; + +face_detector_dlib_cnn::face_detector_dlib_cnn() +{ + p = new private_s; +} + +face_detector_dlib_cnn::~face_detector_dlib_cnn() +{ + delete p; +} + +void face_detector_dlib_cnn::set_texture(std::shared_ptr &tex, int crop_l, int crop_r, int crop_t, int crop_b) +{ + p->tex = tex; + p->crop_l = crop_l; + p->crop_r = crop_r; + p->crop_t = crop_t; + p->crop_b = crop_b; +} + +void face_detector_dlib_cnn::detect_main() +{ + if (!p->tex) + return; + const image_t *img = &p->tex->get_dlib_rgb_image(); + int x0 = 0, y0 = 0; + image_t img_crop; + if (p->crop_l > 0 || p->crop_r > 0 || p->crop_t > 0 || p->crop_b > 0) { + x0 = (int)(p->crop_l / p->tex->scale); + int x1 = img->nc() - (int)(p->crop_r / p->tex->scale); + y0 = (int)(p->crop_t / p->tex->scale); + int y1 = img->nr() - (int)(p->crop_b / p->tex->scale); + if (x1 - x0 < 80 || y1 - y0 < 80) { + if (p->n_error++ < MAX_ERROR) + blog(LOG_ERROR, "too small image: %dx%d cropped left=%d right=%d top=%d bottom=%d", + (int)img->nc(), (int)img->nr(), + p->crop_l, p->crop_r, p->crop_t, p->crop_b ); + return; + } + else if (p->n_error) { + p->n_error--; + } + img_crop.set_size(y1 - y0, x1 - x0); + for (int y = y0; y < y1; y++) { + for (int x = x0; x < x1; x++) { + img_crop(y-y0, x-x0) = (*img)(y, x); + } + } + img = &img_crop; + } + if (img->nc()<80 || img->nr()<80) { + if (p->n_error++ < MAX_ERROR) + blog(LOG_ERROR, "too small image: %dx%d", (int)img->nc(), (int)img->nr()); + return; + } + else if (p->n_error) { + p->n_error--; + } + + if (!p->net_loaded) { + p->net_loaded = true; + try { + blog(LOG_INFO, "loading file '%s'", p->model_filename.c_str()); + deserialize(p->model_filename.c_str()) >> p->net; + p->has_error = false; + } + catch(...) { + blog(LOG_ERROR, "failed to load file '%s'", p->model_filename.c_str()); + p->has_error = true; + } + + } + + if (p->has_error) + return; + + auto dets = p->net(*img); + p->rects.resize(dets.size()); + for (size_t i = 0; i < dets.size(); i++) { + auto &det = dets[i]; + rect_s &r = p->rects[i]; + r.x0 = (det.rect.left() + x0) * p->tex->scale; + r.y0 = (det.rect.top() + y0) * p->tex->scale; + r.x1 = (det.rect.right() + x0) * p->tex->scale; + r.y1 = (det.rect.bottom() + y0) * p->tex->scale; + r.score = det.detection_confidence; + } + + p->tex.reset(); +} + +void face_detector_dlib_cnn::get_faces(std::vector &rects) +{ + rects = p->rects; +} + +void face_detector_dlib_cnn::set_model(const char *filename) +{ + if (p->model_filename != filename) { + p->model_filename = filename; + p->net_loaded = false; + } +} diff --git a/src/face-detector-dlib-cnn.h b/src/face-detector-dlib-cnn.h new file mode 100644 index 0000000..cb90d54 --- /dev/null +++ b/src/face-detector-dlib-cnn.h @@ -0,0 +1,19 @@ +#include +#include +#include +#include "plugin-macros.generated.h" +#include "face-detector-base.h" + +class face_detector_dlib_cnn : public face_detector_base +{ + struct private_s *p; + + void detect_main() override; + public: + face_detector_dlib_cnn(); + virtual ~face_detector_dlib_cnn(); + void set_texture(std::shared_ptr &, int crop_l, int crop_r, int crop_t, int crop_b) override; + void get_faces(std::vector &) override; + + void set_model(const char *filename); +}; diff --git a/src/face-tracker-manager.cpp b/src/face-tracker-manager.cpp index 50ceae2..96d2469 100644 --- a/src/face-tracker-manager.cpp +++ b/src/face-tracker-manager.cpp @@ -2,6 +2,7 @@ #include "plugin-macros.generated.h" #include "face-tracker-manager.hpp" #include "face-detector-dlib-hog.h" +#include "face-detector-dlib-cnn.h" #include "face-tracker-dlib.h" #include "texture-object.h" #include "helper.hpp" @@ -185,6 +186,10 @@ inline void face_tracker_manager::stage_to_detector() detect->set_texture(cvtex, detector_crop_l, detector_crop_r, detector_crop_t, detector_crop_b ); + if (detector_engine == engine_dlib_cnn) { + if (auto *d = dynamic_cast(detect)) + d->set_model(detector_dlib_cnn_model.c_str()); + } detect->signal(); detector_in_progress = true; detect_tick = tick_cnt; @@ -345,6 +350,9 @@ static void update_detector(face_tracker_manager *ftm, enum face_tracker_manager case face_tracker_manager::engine_dlib_hog: ftm->detect = new face_detector_dlib_hog(); break; + case face_tracker_manager::engine_dlib_cnn: + ftm->detect = new face_detector_dlib_cnn(); + break; default: blog(LOG_ERROR, "unknown detector_engine %d", (int)detector_engine); } @@ -365,6 +373,7 @@ void face_tracker_manager::update(obs_data_t *settings) auto _detector_engine = (enum detector_engine_e)obs_data_get_int(settings, "detector_engine"); if (_detector_engine != detector_engine) update_detector(this, _detector_engine); + detector_dlib_cnn_model = obs_data_get_string(settings, "detector_dlib_cnn_model"); detector_crop_l = obs_data_get_int(settings, "detector_crop_l"); detector_crop_r = obs_data_get_int(settings, "detector_crop_r"); detector_crop_t = obs_data_get_int(settings, "detector_crop_t"); @@ -396,6 +405,12 @@ void face_tracker_manager::get_properties(obs_properties_t *pp) obs_properties_add_float(pp, "upsize_t", obs_module_text("Top"), -0.4, 4.0, 0.2); obs_properties_add_float(pp, "upsize_b", obs_module_text("Bottom"), -0.4, 4.0, 0.2); obs_properties_add_float(pp, "scale", obs_module_text("Scale image"), 1.0, 16.0, 1.0); + p = obs_properties_add_list(pp, "detector_engine", obs_module_text("Detector"), + OBS_COMBO_TYPE_LIST, OBS_COMBO_FORMAT_INT); + obs_property_list_add_int(p, obs_module_text("Detector.dlib.hog"), (int)engine_dlib_hog); + obs_property_list_add_int(p, obs_module_text("Detector.dlib.cnn"), (int)engine_dlib_cnn); + obs_properties_add_path(pp, "detector_dlib_cnn_model", obs_module_text("Dlib CNN model"), + OBS_PATH_FILE, "Data Files (*.dat);;" "All Files (*.*)", obs_get_module_data_path(obs_current_module()) ); obs_properties_add_int(pp, "detector_crop_l", obs_module_text("Crop left for detector"), 0, 1920, 1); obs_properties_add_int(pp, "detector_crop_r", obs_module_text("Crop right for detector"), 0, 1920, 1); obs_properties_add_int(pp, "detector_crop_t", obs_module_text("Crop top for detector"), 0, 1080, 1); diff --git a/src/face-tracker-manager.hpp b/src/face-tracker-manager.hpp index be9036a..fbba09f 100644 --- a/src/face-tracker-manager.hpp +++ b/src/face-tracker-manager.hpp @@ -1,6 +1,7 @@ #pragma once #include +#include #include "face-tracker-base.h" class face_tracker_manager @@ -8,6 +9,7 @@ class face_tracker_manager public: enum detector_engine_e { engine_dlib_hog = 0, + engine_dlib_cnn = 1, engine_uninitialized = -1, }; @@ -43,6 +45,7 @@ class face_tracker_manager volatile bool reset_requested; float tracking_threshold; enum detector_engine_e detector_engine = engine_uninitialized; + std::string detector_dlib_cnn_model; int detector_crop_l, detector_crop_r, detector_crop_t, detector_crop_b; char *landmark_detection_data; @@ -58,6 +61,7 @@ class face_tracker_manager class face_detector_base *detect; int detect_tick; + // TODO: Just have two pairs std::deque trackers; std::deque trackers_idlepool;