From 41b1a3ea0d64f8b053955f44d5696987218669da Mon Sep 17 00:00:00 2001 From: Aaryaman Vasishta Date: Thu, 9 Jan 2025 16:26:43 +0000 Subject: [PATCH] Support MPS backend for macs --- README.md | 19 ++++++++++++++++++- gradio_app.py | 14 +++++++------- spar3d/models/diffusion/gaussian_diffusion.py | 7 ++++++- .../image_estimator/clip_based_estimator.py | 9 +++++++-- spar3d/system.py | 17 +++++++++++------ texture_baker/setup.py | 3 ++- uv_unwrapper/setup.py | 3 ++- 7 files changed, 53 insertions(+), 19 deletions(-) diff --git a/README.md b/README.md index cc5e81f..99dfba8 100644 --- a/README.md +++ b/README.md @@ -24,8 +24,9 @@ SPAR3D is based on [Stable Fast 3D](https://github.com/Stability-AI/stable-fast- Ensure your environment is: - Python >= 3.8 (Depending on PyTorch version >3.9) -- Optional: CUDA has to be available +- Optional: CUDA or MPS has to be available - For Windows **(experimental)**: Visual Studio 2022 +- For Mac (MPS) **(experimental)**: OSX 15.2 (Sequoia) and above - Has PyTorch installed according to your platform: https://pytorch.org/get-started/locally/ [Make sure the Pytorch CUDA version matches your system's.] - Update setuptools by `pip install -U setuptools==69.5.1` - Install wheel by `pip install wheel` @@ -46,12 +47,28 @@ Our model is gated at [Hugging Face](https://huggingface.co): To run SPAR3D with low VRAM mode, set the environment variable `SPAR3D_LOW_VRAM=1`. By default, SPAR3D consumes 10.5GB of VRAM. This mode will reduce the VRAM consumption to roughly 7GB but in exchange the model will be slower. The `run.py` script also supports the `--low-vram-mode` flag. +<<<<<<< Updated upstream ### Windows Support **(experimental)** To run Stable Fast 3D on Windows, you must install Visual Studio (currently tested on VS 2022) and the appropriate PyTorch and CUDA versions. Then, follow the installation steps as mentioned above. Note that Windows support is **experimental** and not guaranteed to give the same performance and/or quality as Linux. +======= +### Support for MPS (for Mac Silicon) **(experimental)** + +Stable Fast 3D can also run on Macs via the MPS backend, with the texture baker using custom metal kernels similar to the corresponding CUDA kernels. + +Support is only available for OSX 15.2 (Sequoia) and above. + +Note that support is **experimental** and not guaranteed to give the same performance and/or quality as the CUDA backend. + +MPS backend support was tested on M4 max 36GB with the latest PyTorch release and OSX 15.2 (Sequoia). We recommend you install the latest PyTorch (2.5.1 as of writing) and/or the nightly version to avoid any issues that my arise with older PyTorch versions. + +You also need to run the code with `PYTORCH_ENABLE_MPS_FALLBACK=1`. + +MPS currently consumes more memory compared to the CUDA PyTorch backend. We recommend running the CPU version if your system has less than 32GB of unified memory. +>>>>>>> Stashed changes ### CPU Support diff --git a/gradio_app.py b/gradio_app.py index a78910a..52c605c 100644 --- a/gradio_app.py +++ b/gradio_app.py @@ -83,7 +83,7 @@ def forward_model( system, guidance_scale=3.0, seed=0, - device="cuda", + device=device, remesh_option="none", vertex_count=-1, texture_resolution=1024, @@ -155,16 +155,16 @@ def run_model( if pc_cond is not None: # Check if pc_cond is a list if isinstance(pc_cond, list): - cond_tensor = torch.tensor(pc_cond).float().cuda().view(-1, 6) + cond_tensor = torch.tensor(pc_cond).float().to(device).view(-1, 6) xyz = cond_tensor[:, :3] color_rgb = cond_tensor[:, 3:] elif isinstance(pc_cond, dict): - xyz = torch.tensor(pc_cond["positions"]).float().cuda() - color_rgb = torch.tensor(pc_cond["colors"]).float().cuda() + xyz = torch.tensor(pc_cond["positions"]).float().to(device) + color_rgb = torch.tensor(pc_cond["colors"]).float().to(device) else: - xyz = torch.tensor(pc_cond.vertices).float().cuda() + xyz = torch.tensor(pc_cond.vertices).float().to(device) color_rgb = ( - torch.tensor(pc_cond.colors[:, :3]).float().cuda() / 255.0 + torch.tensor(pc_cond.colors[:, :3]).float().to(device) / 255.0 ) model_batch["pc_cond"] = torch.cat([xyz, color_rgb], dim=-1).unsqueeze( 0 @@ -195,7 +195,7 @@ def run_model( model, guidance_scale=guidance_scale, seed=random_seed, - device="cuda", + device=device, remesh_option=remesh_option.lower(), vertex_count=vertex_count, texture_resolution=texture_resolution, diff --git a/spar3d/models/diffusion/gaussian_diffusion.py b/spar3d/models/diffusion/gaussian_diffusion.py index 83c4060..eddcd96 100644 --- a/spar3d/models/diffusion/gaussian_diffusion.py +++ b/spar3d/models/diffusion/gaussian_diffusion.py @@ -29,6 +29,8 @@ import numpy as np import torch as th +from spar3d.utils import get_device + def sigmoid_schedule(t, start=-3, end=3, tau=0.6, clip_min=1e-9): def sigmoid(x): @@ -169,7 +171,10 @@ def space_timesteps(num_timesteps, section_counts): def _extract_into_tensor(arr, timesteps, broadcast_shape): """Extract values from a 1-D numpy array for a batch of indices.""" - res = th.from_numpy(arr).to(device=timesteps.device)[timesteps].float() + if get_device() == "mps": + res = th.from_numpy(arr.astype(np.float32)).to(device=timesteps.device)[timesteps] + else: + res = th.from_numpy(arr).to(device=timesteps.device)[timesteps].float() while len(res.shape) < len(broadcast_shape): res = res[..., None] return res + th.zeros(broadcast_shape, device=timesteps.device) diff --git a/spar3d/models/image_estimator/clip_based_estimator.py b/spar3d/models/image_estimator/clip_based_estimator.py index 94f1d91..c4e3caa 100644 --- a/spar3d/models/image_estimator/clip_based_estimator.py +++ b/spar3d/models/image_estimator/clip_based_estimator.py @@ -11,6 +11,7 @@ from spar3d.models.network import get_activation from spar3d.models.utils import BaseModule +from spar3d.utils import get_device OPENAI_DATASET_MEAN = (0.48145466, 0.4578275, 0.40821073) OPENAI_DATASET_STD = (0.26862954, 0.26130258, 0.27577711) @@ -120,8 +121,12 @@ def forward( mean=OPENAI_DATASET_MEAN, std=OPENAI_DATASET_STD, )(cond_image) - mask = Normalize(0.5, 0.26)(mask).half() - image_features = self.model.visual(cond_image.half(), mask).float() + if get_device() != "mps": + mask = Normalize(0.5, 0.26)(mask).half() + image_features = self.model.visual(cond_image.half(), mask).float() + else: + mask = Normalize(0.5, 0.26)(mask) + image_features = self.model.visual(cond_image, mask) # Run the heads outputs = {} diff --git a/spar3d/system.py b/spar3d/system.py index 4157986..25651bd 100644 --- a/spar3d/system.py +++ b/spar3d/system.py @@ -404,7 +404,8 @@ def _unload_pdiff_modules(self): self.pdiff_backbone = None self.diffusion_spaced = None self.sampler = None - torch.cuda.empty_cache() + if get_device() == "cuda": + torch.cuda.empty_cache() def _unload_main_modules(self): """Unload main processing modules to free memory""" @@ -414,13 +415,15 @@ def _unload_main_modules(self): self.camera_embedder = None self.backbone = None self.post_processor = None - torch.cuda.empty_cache() + if get_device() == "cuda": + torch.cuda.empty_cache() def _unload_estimator_modules(self): """Unload estimator modules to free memory""" self.image_estimator = None self.global_estimator = None - torch.cuda.empty_cache() + if get_device() == "cuda": + torch.cuda.empty_cache() def triplane_to_meshes( self, triplanes: Float[Tensor, "B 3 Cp Hp Wp"] @@ -663,17 +666,19 @@ def generate_mesh( batch["mask_cond"], self.cfg.cond_image_size ) + device = get_device() + batch_size = batch["rgb_cond"].shape[0] if pointcloud is not None: if isinstance(pointcloud, list): - cond_tensor = torch.tensor(pointcloud).float().cuda().view(-1, 6) + cond_tensor = torch.tensor(pointcloud).float().to(device).view(-1, 6) xyz = cond_tensor[:, :3] color_rgb = cond_tensor[:, 3:] # Check if point cloud is a numpy array elif isinstance(pointcloud, np.ndarray): - xyz = torch.tensor(pointcloud[:, :3]).float().cuda() - color_rgb = torch.tensor(pointcloud[:, 3:]).float().cuda() + xyz = torch.tensor(pointcloud[:, :3]).float().to(device) + color_rgb = torch.tensor(pointcloud[:, 3:]).float().to(device) else: raise ValueError("Invalid point cloud type") diff --git a/texture_baker/setup.py b/texture_baker/setup.py index 8874129..b50838b 100644 --- a/texture_baker/setup.py +++ b/texture_baker/setup.py @@ -33,6 +33,7 @@ def get_extensions(): "-O3" if not debug_mode else "-O0", "-fdiagnostics-color=always", "-fopenmp", + "-c++17", ] + ["-march=native"] if use_native_arch @@ -81,7 +82,7 @@ def get_extensions(): sources += glob.glob( os.path.join(this_dir, library_name, "csrc", "**", "*.mm"), recursive=True ) - extra_compile_args.update({"cxx": ["-O3", "-arch", "arm64"]}) + extra_compile_args.update({"cxx": ["-O3", "-arch", "arm64", "-mmacosx-version-min=10.15"]}) extra_link_args += ["-arch", "arm64"] extensions.append( diff --git a/uv_unwrapper/setup.py b/uv_unwrapper/setup.py index 62cc23e..a7d2ef0 100644 --- a/uv_unwrapper/setup.py +++ b/uv_unwrapper/setup.py @@ -29,7 +29,8 @@ def get_extensions(): ] + ["-march=native"] if use_native_arch - else [], + else [] + + ["-fno-aligned-new"] if is_mac else [], } if debug_mode: extra_compile_args["cxx"].append("-g")