Support MPS backend for macs
jammm committed Jan 9, 2025
1 parent 977e899 commit 41b1a3e
Showing 7 changed files with 53 additions and 19 deletions.
19 changes: 18 additions & 1 deletion README.md
@@ -24,8 +24,9 @@ SPAR3D is based on [Stable Fast 3D](https://github.com/Stability-AI/stable-fast-

Ensure your environment is:
- Python >= 3.8 (or >= 3.9, depending on your PyTorch version)
- Optional: CUDA has to be available
- Optional: CUDA or MPS has to be available
- For Windows **(experimental)**: Visual Studio 2022
- For Mac (MPS) **(experimental)**: macOS 15.2 (Sequoia) and above
- Has PyTorch installed according to your platform: https://pytorch.org/get-started/locally/ [Make sure the PyTorch CUDA version matches your system's.]
- Update setuptools by `pip install -U setuptools==69.5.1`
- Install wheel by `pip install wheel`
@@ -46,12 +47,28 @@ Our model is gated at [Hugging Face](https://huggingface.co):

To run SPAR3D with low VRAM mode, set the environment variable `SPAR3D_LOW_VRAM=1`. By default, SPAR3D consumes 10.5GB of VRAM; this mode reduces consumption to roughly 7GB at the cost of slower inference. The `run.py` script also supports the `--low-vram-mode` flag.

### Windows Support **(experimental)**

To run SPAR3D on Windows, you must install Visual Studio (currently tested on VS 2022) and the appropriate PyTorch and CUDA versions.
Then, follow the installation steps as mentioned above.

Note that Windows support is **experimental** and not guaranteed to give the same performance and/or quality as Linux.
### Support for MPS (Apple Silicon Macs) **(experimental)**

SPAR3D can also run on Macs via the MPS backend, with the texture baker using custom Metal kernels that mirror the corresponding CUDA kernels.

Support is only available on macOS 15.2 (Sequoia) and above.

Note that support is **experimental** and not guaranteed to give the same performance and/or quality as the CUDA backend.

MPS backend support was tested on an M4 Max with 36GB of unified memory, the latest PyTorch release, and macOS 15.2 (Sequoia). We recommend installing the latest PyTorch (2.5.1 as of writing) or the nightly build to avoid issues that may arise with older PyTorch versions.
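
A quick way to confirm your PyTorch build and which accelerator it can see (a minimal check, not part of the repository):

```python
import torch

print("PyTorch:", torch.__version__)
print("CUDA available:", torch.cuda.is_available())
print("MPS available:", torch.backends.mps.is_available())  # requires PyTorch >= 1.12
```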

You also need to run the code with `PYTORCH_ENABLE_MPS_FALLBACK=1`.
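
The fallback flag can be exported in the shell or set from Python; the sketch below is not taken from `run.py`:

```python
import os

# With this set, operators the MPS backend does not implement fall back
# to the CPU instead of raising an error. Setting it before importing
# torch is the safe choice.
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"

import torch  # noqa: E402  (imported only after the variable is set)
```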

MPS currently consumes more memory than the CUDA backend. We recommend running the CPU version if your system has less than 32GB of unified memory.
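
The backend order this section implies (CUDA if present, otherwise MPS, otherwise CPU) can be sketched as follows; this is an illustration, not the repository's actual `get_device()`:

```python
import torch


def pick_device() -> str:
    """Illustrative device choice: prefer CUDA, then MPS, then CPU."""
    if torch.cuda.is_available():
        return "cuda"
    if torch.backends.mps.is_available():
        return "mps"
    return "cpu"
```
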
### CPU Support

14 changes: 7 additions & 7 deletions gradio_app.py
@@ -83,7 +83,7 @@ def forward_model(
system,
guidance_scale=3.0,
seed=0,
device="cuda",
device=device,
remesh_option="none",
vertex_count=-1,
texture_resolution=1024,
@@ -155,16 +155,16 @@ def run_model(
if pc_cond is not None:
# Check if pc_cond is a list
if isinstance(pc_cond, list):
cond_tensor = torch.tensor(pc_cond).float().cuda().view(-1, 6)
cond_tensor = torch.tensor(pc_cond).float().to(device).view(-1, 6)
xyz = cond_tensor[:, :3]
color_rgb = cond_tensor[:, 3:]
elif isinstance(pc_cond, dict):
xyz = torch.tensor(pc_cond["positions"]).float().cuda()
color_rgb = torch.tensor(pc_cond["colors"]).float().cuda()
xyz = torch.tensor(pc_cond["positions"]).float().to(device)
color_rgb = torch.tensor(pc_cond["colors"]).float().to(device)
else:
xyz = torch.tensor(pc_cond.vertices).float().cuda()
xyz = torch.tensor(pc_cond.vertices).float().to(device)
color_rgb = (
torch.tensor(pc_cond.colors[:, :3]).float().cuda() / 255.0
torch.tensor(pc_cond.colors[:, :3]).float().to(device) / 255.0
)
model_batch["pc_cond"] = torch.cat([xyz, color_rgb], dim=-1).unsqueeze(
0
@@ -195,7 +195,7 @@ def run_model(
model,
guidance_scale=guidance_scale,
seed=random_seed,
device="cuda",
device=device,
remesh_option=remesh_option.lower(),
vertex_count=vertex_count,
texture_resolution=texture_resolution,
7 changes: 6 additions & 1 deletion spar3d/models/diffusion/gaussian_diffusion.py
@@ -29,6 +29,8 @@
import numpy as np
import torch as th

from spar3d.utils import get_device


def sigmoid_schedule(t, start=-3, end=3, tau=0.6, clip_min=1e-9):
def sigmoid(x):
@@ -169,7 +171,10 @@ def space_timesteps(num_timesteps, section_counts):

def _extract_into_tensor(arr, timesteps, broadcast_shape):
"""Extract values from a 1-D numpy array for a batch of indices."""
res = th.from_numpy(arr).to(device=timesteps.device)[timesteps].float()
if get_device() == "mps":
    res = th.from_numpy(arr.astype(np.float32)).to(device=timesteps.device)[timesteps]
else:
    res = th.from_numpy(arr).to(device=timesteps.device)[timesteps].float()
while len(res.shape) < len(broadcast_shape):
res = res[..., None]
return res + th.zeros(broadcast_shape, device=timesteps.device)
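
The MPS branch above exists because the backend has no 64-bit float support: the diffusion schedule arrays are float64 NumPy arrays, so they are cast to float32 on the NumPy side before being moved to the device. A small illustration of the pattern (the schedule values here are made up):

```python
import numpy as np
import torch

arr = np.linspace(1e-4, 2e-2, 1000)  # float64 by default, like a beta schedule

if torch.backends.mps.is_available():
    # Sending a float64 tensor to MPS raises a TypeError, so cast first.
    res = torch.from_numpy(arr.astype(np.float32)).to("mps")
else:
    res = torch.from_numpy(arr).float()
print(res.dtype, res.device)
```
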
9 changes: 7 additions & 2 deletions spar3d/models/image_estimator/clip_based_estimator.py
@@ -11,6 +11,7 @@

from spar3d.models.network import get_activation
from spar3d.models.utils import BaseModule
from spar3d.utils import get_device

OPENAI_DATASET_MEAN = (0.48145466, 0.4578275, 0.40821073)
OPENAI_DATASET_STD = (0.26862954, 0.26130258, 0.27577711)
@@ -120,8 +121,12 @@ def forward(
mean=OPENAI_DATASET_MEAN,
std=OPENAI_DATASET_STD,
)(cond_image)
mask = Normalize(0.5, 0.26)(mask).half()
image_features = self.model.visual(cond_image.half(), mask).float()
if get_device() != "mps":
    mask = Normalize(0.5, 0.26)(mask).half()
    image_features = self.model.visual(cond_image.half(), mask).float()
else:
    mask = Normalize(0.5, 0.26)(mask)
    image_features = self.model.visual(cond_image, mask)

# Run the heads
outputs = {}
17 changes: 11 additions & 6 deletions spar3d/system.py
@@ -404,7 +404,8 @@ def _unload_pdiff_modules(self):
self.pdiff_backbone = None
self.diffusion_spaced = None
self.sampler = None
torch.cuda.empty_cache()
if get_device() == "cuda":
    torch.cuda.empty_cache()

def _unload_main_modules(self):
"""Unload main processing modules to free memory"""
@@ -414,13 +415,15 @@ def _unload_main_modules(self):
self.camera_embedder = None
self.backbone = None
self.post_processor = None
torch.cuda.empty_cache()
if get_device() == "cuda":
    torch.cuda.empty_cache()

def _unload_estimator_modules(self):
"""Unload estimator modules to free memory"""
self.image_estimator = None
self.global_estimator = None
torch.cuda.empty_cache()
if get_device() == "cuda":
    torch.cuda.empty_cache()
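
The repeated guard around `torch.cuda.empty_cache()` could also live in one device-aware helper; this is a sketch only, and `empty_device_cache` is not a function in the repository:

```python
import torch


def empty_device_cache(device: str) -> None:
    """Free cached allocator memory for the given backend (sketch)."""
    if device == "cuda":
        torch.cuda.empty_cache()
    elif device == "mps":
        # torch.mps.empty_cache() exists in recent PyTorch releases.
        if hasattr(torch, "mps") and hasattr(torch.mps, "empty_cache"):
            torch.mps.empty_cache()
```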

def triplane_to_meshes(
self, triplanes: Float[Tensor, "B 3 Cp Hp Wp"]
@@ -663,17 +666,19 @@ def generate_mesh(
batch["mask_cond"], self.cfg.cond_image_size
)

device = get_device()

batch_size = batch["rgb_cond"].shape[0]

if pointcloud is not None:
if isinstance(pointcloud, list):
cond_tensor = torch.tensor(pointcloud).float().cuda().view(-1, 6)
cond_tensor = torch.tensor(pointcloud).float().to(device).view(-1, 6)
xyz = cond_tensor[:, :3]
color_rgb = cond_tensor[:, 3:]
# Check if point cloud is a numpy array
elif isinstance(pointcloud, np.ndarray):
xyz = torch.tensor(pointcloud[:, :3]).float().cuda()
color_rgb = torch.tensor(pointcloud[:, 3:]).float().cuda()
xyz = torch.tensor(pointcloud[:, :3]).float().to(device)
color_rgb = torch.tensor(pointcloud[:, 3:]).float().to(device)
else:
raise ValueError("Invalid point cloud type")

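
For reference, the point-cloud layouts accepted above are a Python list or an `(N, 6)` array of `(x, y, z, r, g, b)` values per point; the 0 to 1 color range used here is an assumption based on the absence of a `/255` rescale in this branch:

```python
import numpy as np

# Hypothetical input matching the np.ndarray branch of generate_mesh:
# one row per point, xyz followed by rgb.
pointcloud = np.array(
    [
        [0.00, 0.00, 0.00, 1.0, 0.0, 0.0],
        [0.05, 0.00, 0.00, 0.0, 1.0, 0.0],
        [0.00, 0.05, 0.00, 0.0, 0.0, 1.0],
    ],
    dtype=np.float32,
)
```
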
3 changes: 2 additions & 1 deletion texture_baker/setup.py
@@ -33,6 +33,7 @@ def get_extensions():
"-O3" if not debug_mode else "-O0",
"-fdiagnostics-color=always",
"-fopenmp",
"-c++17",
]
+ ["-march=native"]
if use_native_arch
@@ -81,7 +82,7 @@ def get_extensions():
sources += glob.glob(
os.path.join(this_dir, library_name, "csrc", "**", "*.mm"), recursive=True
)
extra_compile_args.update({"cxx": ["-O3", "-arch", "arm64"]})
extra_compile_args.update({"cxx": ["-O3", "-arch", "arm64", "-mmacosx-version-min=10.15"]})
extra_link_args += ["-arch", "arm64"]

extensions.append(
3 changes: 2 additions & 1 deletion uv_unwrapper/setup.py
@@ -29,7 +29,8 @@ def get_extensions():
]
+ ["-march=native"]
if use_native_arch
else [],
else []
+ ["-fno-aligned-new"] if is_mac else [],
}
if debug_mode:
extra_compile_args["cxx"].append("-g")