-
Notifications
You must be signed in to change notification settings - Fork 280
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Enhance Device and Precision Handling, and Improve Error Messages in DepthPro Model #35
base: main
Are you sure you want to change the base?
Conversation
hey @Amael, please review this. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Might have issues using device
in the proposed fashion in function create_model_and_transforms
src/depth_pro/depth_pro.py
Outdated
device: torch.device = torch.device("cpu"), | ||
precision: torch.dtype = torch.float32, | ||
device: torch.device = torch.device("cuda" if torch.cuda.is_available() else "cpu"), | ||
precision: torch.dtype = torch.float16 if device.type == 'cuda' else torch.float32, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This use of device
at this point doesn't seem to work for me.
I have to do this precision = torch.float16 if device.type == 'cuda' else torch.float32
in the body of the function
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
When running depth-pro-run -i PATH_TO_MY_IMG_FILE
, I also get a problem:
Traceback (most recent call last):
File "miniconda3/envs/depth-pro/bin/depth-pro-run", line 8, in <module>
sys.exit(run_main())
File "ml-depth-pro/src/depth_pro/cli/run.py", line 150, in main
run(parser.parse_args())
File "ml-depth-pro/src/depth_pro/cli/run.py", line 68, in run
prediction = model.infer(transform(image), f_px=f_px)
File "miniconda3/envs/depth-pro/lib/python3.9/site-packages/torchvision/transforms/transforms.py", line 95, in __call__
img = t(img)
File "miniconda3/envs/depth-pro/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "miniconda3/envs/depth-pro/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
return forward_call(*args, **kwargs)
File "miniconda3/envs/depth-pro/lib/python3.9/site-packages/torchvision/transforms/transforms.py", line 198, in forward
return F.convert_image_dtype(image, self.dtype)
File "miniconda3/envs/depth-pro/lib/python3.9/site-packages/torchvision/transforms/functional.py", line 243, in convert_image_dtype
return F_t.convert_image_dtype(image, dtype)
File "miniconda3/envs/depth-pro/lib/python3.9/site-packages/torchvision/transforms/_functional_tensor.py", line 73, in convert_image_dtype
if torch.tensor(0, dtype=dtype).is_floating_point():
TypeError: tensor(): argument 'dtype' must be torch.dtype, not tuple
@carlos-bg i have made some changes please review this. |
@carlos-bg please review |
Pull Request Description
This pull request introduces several improvements to the DepthPro model code:
Device and Precision Handling: The model now dynamically selects the appropriate device (
cuda
orcpu
) based on availability. Additionally, handling for half precision (torch.half
) has been implemented to enhance performance on compatible devices.Improved Error Messages: Enhanced error messages for loading model state dictionaries provide clearer feedback on any issues that arise during the loading process.
These enhancements aim to improve the usability and performance of the DepthPro model, making it more efficient and user-friendly.