Standardize on PyTorch for machine learning models #146
Closed
10 commits
f56b8c9 Create new module for estimating face properties using PyTorch (LoyVanBeek)
5762a39 Remove python2 support, no need for it anymore (LoyVanBeek)
01480a0 Download model/weights automagically (LoyVanBeek)
a66f0f6 Fixup some typos (LoyVanBeek)
6fdb9a3 Update README.md (LoyVanBeek)
e68be5b docs: Update README with right package (ar13pit)
20a7c7a build(image_recognition_pytorch): Add missing dependency (ar13pit)
71abe78 fix(image_recognition_pytorch): Fix default model path (ar13pit)
369f916 feat(image_recognition_pytorch): Add support for GPU (ar13pit)
486dc6a feat(image_recognition_rqt): Fix python3 imports (ar13pit)
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
cmake_minimum_required(VERSION 3.0.2) | ||
project(image_recognition_pytorch) | ||
|
||
find_package(catkin REQUIRED) | ||
|
||
catkin_python_setup() | ||
|
||
catkin_package() | ||
|
||
install(PROGRAMS | ||
  scripts/face_properties_node | ||
  scripts/get_face_properties | ||
  DESTINATION ${CATKIN_PACKAGE_BIN_DESTINATION} | ||
) | ||
|
||
if (CATKIN_ENABLE_TESTING) | ||
  # Test catkin lint | ||
  find_program(CATKIN_LINT catkin_lint REQUIRED) | ||
  execute_process(COMMAND "${CATKIN_LINT}" "-q" "-W2" "${CMAKE_SOURCE_DIR}" RESULT_VARIABLE lint_result) | ||
  if(NOT ${lint_result} EQUAL 0) | ||
    message(FATAL_ERROR "catkin_lint failed") | ||
  endif() | ||
|
||
  catkin_add_nosetests(test) | ||
endif() |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
# Image recognition pytorch | ||
|
||
Image recognition (age and gender estimation of a face) using PyTorch. | ||
|
||
## Installation | ||
|
||
See https://github.com/tue-robotics/image_recognition for installation instructions. | ||
|
||
## ROS Node (face_properties_node) | ||
|
||
Age and gender estimation | ||
``` | ||
rosrun image_recognition_pytorch face_properties_node _weights_file_path:=[path_to_model] | ||
``` | ||
|
||
Run the image_recognition_rqt test gui (https://github.com/tue-robotics/image_recognition_rqt) | ||
|
||
rosrun image_recognition_rqt test_gui | ||
|
||
Configure the service you want to call with the gear wheel in the top-right corner of the screen. If everything is set up, draw a rectangle in the image around a face: | ||
|
||
![Wide ResNet](doc/wide_resnet_test.png) | ||
|
||
## Scripts | ||
|
||
### Download model | ||
|
||
Download the weights from GitHub. | ||
|
||
``` | ||
usage: download_model [-h] [--model_path MODEL_PATH] | ||
|
||
optional arguments: | ||
-h, --help show this help message and exit | ||
--model_path MODEL_PATH | ||
``` | ||
|
||
### Get face properties (get_face_properties) | ||
|
||
Get the classification result of an input image: | ||
|
||
``` | ||
rosrun image_recognition_pytorch get_face_properties `rospack find image_recognition_pytorch`/doc/face.png | ||
``` | ||
|
||
![Example](doc/face.png) | ||
|
||
Output: | ||
|
||
[(50.5418073660112, array([0.5845756 , 0.41542447], dtype=float32))] |
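The output is a list with one `(age, gender_scores)` tuple per input face. A minimal sketch of how such a tuple might be read follows. It mirrors the thresholding in `face_properties_node`, which treats index 0 of the scores as female; `GENDER_DICT` in the estimator lists index 0 as male, so the index convention used here is an assumption. The helper name `summarize` and its `female_index` parameter are hypothetical.

```python
def summarize(estimation, female_index=0):
    """Turn one (age, gender_scores) tuple into a readable summary.

    female_index is an assumption: the node in this PR treats index 0 as female.
    """
    age, scores = estimation
    male_index = 1 - female_index
    if scores[female_index] > 0.5:
        gender, confidence = "female", scores[female_index]
    else:
        gender, confidence = "male", scores[male_index]
    # The node also truncates the age to an integer before publishing it
    return {"age": int(age), "gender": gender, "confidence": float(confidence)}


# Values taken from the README output, written as plain floats
print(summarize((50.5418073660112, (0.5845756, 0.41542447))))
```

If the model's index order turns out to be the other way around, passing `female_index=1` flips the interpretation without touching the rest of the logic.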
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
<?xml version="1.0"?> | ||
<?xml-model | ||
href="http://download.ros.org/schema/package_format3.xsd" | ||
schematypens="http://www.w3.org/2001/XMLSchema"?> | ||
<package format="3"> | ||
<name>image_recognition_pytorch</name> | ||
<version>0.0.1</version> | ||
<description>The image_recognition_pytorch package</description> | ||
|
||
<maintainer email="[email protected]">Loy van Beek</maintainer> | ||
|
||
<license>MIT</license> | ||
|
||
<buildtool_depend>catkin</buildtool_depend> | ||
|
||
<buildtool_depend>python3-setuptools</buildtool_depend> | ||
|
||
<exec_depend>diagnostic_updater</exec_depend> | ||
<exec_depend>image_recognition_msgs</exec_depend> | ||
<exec_depend>image_recognition_util</exec_depend> | ||
<exec_depend>python3-numpy</exec_depend> | ||
<exec_depend>python3-opencv</exec_depend> | ||
<exec_depend>python3-onnxruntime-pip</exec_depend> | ||
<exec_depend>python3-pytorch-pip</exec_depend> | ||
<exec_depend>rospy</exec_depend> | ||
|
||
<test_depend>python3-catkin-lint</test_depend> | ||
<test_depend>python3-future</test_depend> | ||
<test_depend>python3-rospkg</test_depend> | ||
|
||
<doc_depend>python3-sphinx</doc_depend> | ||
<doc_depend>python-sphinx-autoapi-pip</doc_depend> | ||
<doc_depend>python-sphinx-rtd-theme-pip</doc_depend> | ||
<doc_depend>python3-yaml</doc_depend> | ||
|
||
<export> | ||
<rosdoc config="rosdoc.yaml" /> | ||
</export> | ||
</package> |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
- builder: sphinx | ||
  sphinx_root_dir: docs | ||
  name: Python API |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
#!/usr/bin/env python3 | ||
from __future__ import print_function | ||
import os | ||
import urllib.request | ||
|
||
import argparse | ||
|
||
parser = argparse.ArgumentParser() | ||
parser.add_argument('--model_path', default=os.path.expanduser('~/data/pytorch_models')) | ||
args = parser.parse_args() | ||
|
||
# Create the target directory without shelling out, so paths containing spaces also work | ||
os.makedirs(args.model_path, exist_ok=True) | ||
local_path = os.path.join(args.model_path, 'best-epoch47-0.9314.onnx') | ||
|
||
if not os.path.exists(local_path): | ||
    # TODO: Clone this for us | ||
    http_path = "https://github.com/Nebula4869/PyTorch-gender-age-estimation/raw/" \ | ||
                "038331d26fc1fbf24d00365d0eb9d0e5e828dda6/models-2020-11-20-14-37/best-epoch47-0.9314.onnx" | ||
    print("Downloading model to {} ...".format(local_path)) | ||
    urllib.request.urlretrieve(http_path, local_path) | ||
    print("Model downloaded: {}".format(local_path)) | ||
else: | ||
    print("Model already downloaded: {}".format(local_path)) |
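The script decides whether to download by checking that the target file exists, so an interrupted transfer could leave a truncated file that is later taken for the model. One possible hardening, sketched here under assumptions, downloads to a temporary name and renames it only once the transfer has finished. The function name `download_if_missing`, the `.part` suffix, and the injectable `fetch` parameter are hypothetical and not part of this PR.

```python
import os
import urllib.request


def download_if_missing(url, local_path, fetch=urllib.request.urlretrieve):
    """Download url to local_path unless the file already exists.

    Returns True if a download happened, False if the file was already there.
    """
    if os.path.exists(local_path):
        return False
    directory = os.path.dirname(local_path)
    if directory:
        os.makedirs(directory, exist_ok=True)
    partial_path = local_path + ".part"
    fetch(url, partial_path)
    # Rename only after the transfer finished, so a partial file is never mistaken for the model
    os.replace(partial_path, local_path)
    return True
```

Passing `fetch` as a parameter keeps the helper testable without network access; by default it uses the same `urllib.request.urlretrieve` call as the script.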
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,105 @@ | ||
#!/usr/bin/env python | ||
import os | ||
import sys | ||
|
||
import diagnostic_updater | ||
import rospy | ||
from cv_bridge import CvBridge, CvBridgeError | ||
from image_recognition_pytorch.age_gender_estimator import AgeGenderEstimator | ||
from image_recognition_msgs.msg import FaceProperties | ||
from image_recognition_msgs.srv import GetFaceProperties | ||
from image_recognition_util import image_writer | ||
|
||
|
||
class PytorchFaceProperties: | ||
    def __init__(self, weights_file_path, img_size, depth, width, save_images_folder, use_gpu): | ||
        """ | ||
        ROS node that wraps the PyTorch age gender estimator | ||
        """ | ||
        self._bridge = CvBridge() | ||
        self._properties_srv = rospy.Service('get_face_properties', GetFaceProperties, self._get_face_properties_srv) | ||
        self._estimator = AgeGenderEstimator(weights_file_path, img_size, depth, width, use_gpu) | ||
|
||
        if save_images_folder: | ||
            self._save_images_folder = os.path.expanduser(save_images_folder) | ||
            if not os.path.exists(self._save_images_folder): | ||
                os.makedirs(self._save_images_folder) | ||
        else: | ||
            self._save_images_folder = None | ||
|
||
        rospy.loginfo("PytorchFaceProperties node initialized:") | ||
        rospy.loginfo(" - weights_file_path=%s", weights_file_path) | ||
        rospy.loginfo(" - img_size=%s", img_size) | ||
        rospy.loginfo(" - depth=%s", depth) | ||
        rospy.loginfo(" - width=%s", width) | ||
        rospy.loginfo(" - save_images_folder=%s", save_images_folder) | ||
        rospy.loginfo(" - use_gpu=%s", use_gpu) | ||
|
||
    def _get_face_properties_srv(self, req): | ||
        """ | ||
        Callback when the GetFaceProperties service is called | ||
|
||
        :param req: Input images | ||
        :return: properties | ||
        """ | ||
        # Convert to opencv images | ||
        try: | ||
            bgr_images = [self._bridge.imgmsg_to_cv2(image, "bgr8") for image in req.face_image_array] | ||
        except CvBridgeError as e: | ||
            raise Exception("Could not convert image to opencv image: %s" % str(e)) | ||
|
||
        rospy.loginfo("Estimating the age and gender of %d incoming images ...", len(bgr_images)) | ||
        estimations = self._estimator.estimate(bgr_images) | ||
        rospy.loginfo("Done") | ||
|
||
        face_properties_array = [] | ||
        for (age, gender_prob) in estimations: | ||
            gender, gender_confidence = (FaceProperties.FEMALE, gender_prob[0]) if gender_prob[0] > 0.5 else (FaceProperties.MALE, gender_prob[1]) | ||
|
||
            face_properties_array.append(FaceProperties( | ||
                age=int(age), | ||
                gender=gender, | ||
                gender_confidence=gender_confidence | ||
            )) | ||
|
||
        # Store images if specified | ||
        if self._save_images_folder: | ||
            def _get_label(p): | ||
                return "age_%d_gender_%s" % (p.age, "male" if p.gender == FaceProperties.MALE else "female") | ||
|
||
            image_writer.write_estimations(self._save_images_folder, bgr_images, | ||
                                           [_get_label(p) for p in face_properties_array], | ||
                                           suffix="_face_properties") | ||
|
||
        # Service response | ||
        return {"properties_array": face_properties_array} | ||
|
||
|
||
if __name__ == '__main__': | ||
    rospy.init_node("face_properties") | ||
|
||
    try: | ||
        default_weights_path = os.path.expanduser('~/data/pytorch_models/best-epoch47-0.9314.onnx') | ||
        weights_file_path = rospy.get_param("~weights_file_path", default_weights_path) | ||
        img_size = rospy.get_param("~image_size", 64) | ||
        depth = rospy.get_param("~depth", 16) | ||
        width = rospy.get_param("~width", 8) | ||
        save_images = rospy.get_param("~save_images", True) | ||
        use_gpu = rospy.get_param("~use_gpu", False) | ||
|
||
        save_images_folder = None | ||
        if save_images: | ||
            save_images_folder = rospy.get_param("~save_images_folder", "/tmp/image_recognition_pytorch") | ||
    except KeyError as e: | ||
        rospy.logerr("Parameter %s not found" % e) | ||
        sys.exit(1) | ||
|
||
    try: | ||
        PytorchFaceProperties(weights_file_path, img_size, depth, width, save_images_folder, use_gpu) | ||
        updater = diagnostic_updater.Updater() | ||
        updater.setHardwareID("none") | ||
        updater.add(diagnostic_updater.Heartbeat()) | ||
        rospy.Timer(rospy.Duration(1), lambda event: updater.force_update()) | ||
        rospy.spin() | ||
    except Exception as e: | ||
        rospy.logfatal(e) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
#!/usr/bin/env python | ||
from __future__ import print_function | ||
import argparse | ||
from image_recognition_pytorch.age_gender_estimator import AgeGenderEstimator | ||
import cv2 | ||
import os | ||
|
||
# Assign description to the help doc | ||
parser = argparse.ArgumentParser(description='Get face properties using PyTorch') | ||
|
||
# Add arguments | ||
parser.add_argument('image', type=str, help='Image') | ||
parser.add_argument('--weights-path', type=str, help='Path to the weights of the WideResnet model', | ||
                    default=os.path.expanduser('~/data/pytorch_models/best-epoch47-0.9314.onnx')) | ||
parser.add_argument('--image-size', type=int, help='Size of the input image', default=64) | ||
parser.add_argument('--depth', type=int, help='Depth of the network', default=16) | ||
parser.add_argument('--width', type=int, help='Width of the network', default=8) | ||
|
||
args = parser.parse_args() | ||
|
||
# Read the image | ||
img = cv2.imread(args.image) | ||
|
||
estimator = AgeGenderEstimator(args.weights_path, args.image_size, args.depth, args.width) | ||
|
||
print(estimator.estimate([img])) |
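`cv2.imread` returns `None` rather than raising when a file is missing or cannot be decoded, so a mistyped path in this script would only surface later, inside the estimator. A minimal sketch of an early check follows; the helper name `read_image` and its injectable `reader` parameter are hypothetical and not part of this PR.

```python
import os


def read_image(path, reader=None):
    """Read an image and fail early with a clear message if that is not possible."""
    if reader is None:
        import cv2  # imported lazily so the helper can be exercised without OpenCV installed
        reader = cv2.imread
    if not os.path.isfile(path):
        raise IOError("No such image file: {}".format(path))
    img = reader(path)
    if img is None:
        # The reader signals an unreadable or undecodable file by returning None
        raise IOError("Could not decode image: {}".format(path))
    return img
```

In the script, `img = read_image(args.image)` could stand in for the direct `cv2.imread` call.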
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
from setuptools import setup | ||
from catkin_pkg.python_setup import generate_distutils_setup | ||
|
||
d = generate_distutils_setup( | ||
packages=['image_recognition_pytorch'], | ||
package_dir={'': 'src'} | ||
) | ||
|
||
setup(**d) |
1 change: 1 addition & 0 deletions
image_recognition_pytorch/src/image_recognition_pytorch/__init__.py
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
from . import age_gender_estimator |
61 changes: 61 additions & 0 deletions
image_recognition_pytorch/src/image_recognition_pytorch/age_gender_estimator.py
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
import cv2 | ||
import numpy as np | ||
import os.path | ||
|
||
import onnxruntime | ||
|
||
GENDER_DICT = {0: 'male', 1: 'female'} | ||
|
||
|
||
class AgeGenderEstimator(object): | ||
    def __init__(self, weights_file_path, img_size=64, depth=16, width=8, use_gpu=False): | ||
        """ | ||
        Estimate the age and gender of the incoming image | ||
|
||
        :param weights_file_path: path to a pre-trained network in onnx format | ||
        """ | ||
        weights_file_path = os.path.expanduser(weights_file_path) | ||
|
||
        if not os.path.isfile(weights_file_path): | ||
            raise IOError("Weights file {} does not exist".format(weights_file_path)) | ||
|
||
        self._model = None | ||
        self._weights_file_path = weights_file_path | ||
        self._img_size = img_size | ||
        self._depth = depth | ||
        self._width = width | ||
        self._use_gpu = use_gpu | ||
|
||
    def estimate(self, np_images): | ||
        """ | ||
        Estimate the age and gender of the face on the image | ||
|
||
        :param np_images: a numpy array of BGR images of faces of which the gender and the age has to be estimated | ||
            This is assumed to be segmented/cropped already! | ||
        :returns: List of estimated age and gender score ([female, male]) tuples | ||
        """ | ||
|
||
        # Model should be constructed in same thread as the inference | ||
        if self._model is None: | ||
            providers = ['CPUExecutionProvider'] | ||
            if self._use_gpu: | ||
                # ONNX Runtime tries providers in list order, so CUDA goes before the CPU fallback | ||
                providers.insert(0, | ||
                    ('CUDAExecutionProvider', { | ||
                        'device_id': 0, | ||
                        'arena_extend_strategy': 'kNextPowerOfTwo', | ||
                        'gpu_mem_limit': 2 * 1024 * 1024 * 1024, | ||
                        'cudnn_conv_algo_search': 'EXHAUSTIVE', | ||
                        'do_copy_in_default_stream': True, | ||
                    })) | ||
|
||
            self._model = onnxruntime.InferenceSession(self._weights_file_path, providers=providers) | ||
|
||
        results = [] | ||
        for np_image in np_images: | ||
            # Resize to the configured input size (64 by default) and reorder HWC to CHW | ||
            inputs = np.transpose(cv2.resize(np_image, (self._img_size, self._img_size)), (2, 0, 1)) | ||
            inputs = np.expand_dims(inputs, 0).astype(np.float32) / 255. | ||
            predictions = self._model.run(['output'], input_feed={'input': inputs})[0][0] | ||
            # age p(male) p(female) | ||
            results += [(predictions[2], (predictions[0], predictions[1]))] | ||
|
||
        return results |
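ONNX Runtime treats the `providers` list passed to `InferenceSession` as ordered by priority, so the position of `CUDAExecutionProvider` relative to `CPUExecutionProvider` decides which one is preferred. A minimal sketch of building the list with the GPU first and the CPU as fallback follows; the helper name `build_providers` is hypothetical and only the `device_id` option is carried over, for brevity.

```python
def build_providers(use_gpu, device_id=0):
    """Return an ONNX Runtime provider list, highest priority first."""
    providers = ['CPUExecutionProvider']
    if use_gpu:
        # CUDA goes first so it is preferred; the CPU provider stays as the fallback
        providers.insert(0, ('CUDAExecutionProvider', {'device_id': device_id}))
    return providers


print(build_providers(use_gpu=True))
```

The returned list can be passed as the `providers=` argument when constructing the session.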
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
#!/bin/bash | ||
nosetests -vv "$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" | ||
Why is this needed? Is it only for manual testing? The tests already run through the catkin macro.