Skip to content

Commit

Permalink
Support MPS backend for macs
Browse files Browse the repository at this point in the history
  • Loading branch information
jammm committed Jan 9, 2025
1 parent 977e899 commit 58cec73
Show file tree
Hide file tree
Showing 7 changed files with 51 additions and 19 deletions.
17 changes: 16 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,9 @@ SPAR3D is based on [Stable Fast 3D](https://github.com/Stability-AI/stable-fast-

Ensure your environment is:
- Python >= 3.8 (Depending on PyTorch version >3.9)
- Optional: CUDA has to be available
- Optional: CUDA or MPS has to be available
- For Windows **(experimental)**: Visual Studio 2022
- For Mac (MPS) **(experimental)**: OSX 15.2 (Sequoia) and above
- Has PyTorch installed according to your platform: https://pytorch.org/get-started/locally/ [Make sure the Pytorch CUDA version matches your system's.]
- Update setuptools by `pip install -U setuptools==69.5.1`
- Install wheel by `pip install wheel`
Expand Down Expand Up @@ -53,6 +54,20 @@ Then, follow the installation steps as mentioned above.

Note that Windows support is **experimental** and not guaranteed to give the same performance and/or quality as Linux.

### Support for MPS (for Mac Silicon) **(experimental)**

Stable Fast 3D can also run on Macs via the MPS backend, with the texture baker using custom metal kernels similar to the corresponding CUDA kernels.

Support is only available for OSX 15.2 (Sequoia) and above.

Note that support is **experimental** and not guaranteed to give the same performance and/or quality as the CUDA backend.

MPS backend support was tested on M4 max 36GB with the latest PyTorch release and OSX 15.2 (Sequoia). We recommend you install the latest PyTorch (2.5.1 as of writing) and/or the nightly version to avoid any issues that my arise with older PyTorch versions.

You also need to run the code with `PYTORCH_ENABLE_MPS_FALLBACK=1`.

MPS currently consumes more memory compared to the CUDA PyTorch backend. We recommend running the CPU version if your system has less than 32GB of unified memory.

### CPU Support

CPU backend will automatically be used if no GPU is detected in your system. Note that this will be really slow.
Expand Down
14 changes: 7 additions & 7 deletions gradio_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def forward_model(
system,
guidance_scale=3.0,
seed=0,
device="cuda",
device=device,
remesh_option="none",
vertex_count=-1,
texture_resolution=1024,
Expand Down Expand Up @@ -155,16 +155,16 @@ def run_model(
if pc_cond is not None:
# Check if pc_cond is a list
if isinstance(pc_cond, list):
cond_tensor = torch.tensor(pc_cond).float().cuda().view(-1, 6)
cond_tensor = torch.tensor(pc_cond).float().to(device).view(-1, 6)
xyz = cond_tensor[:, :3]
color_rgb = cond_tensor[:, 3:]
elif isinstance(pc_cond, dict):
xyz = torch.tensor(pc_cond["positions"]).float().cuda()
color_rgb = torch.tensor(pc_cond["colors"]).float().cuda()
xyz = torch.tensor(pc_cond["positions"]).float().to(device)
color_rgb = torch.tensor(pc_cond["colors"]).float().to(device)
else:
xyz = torch.tensor(pc_cond.vertices).float().cuda()
xyz = torch.tensor(pc_cond.vertices).float().to(device)
color_rgb = (
torch.tensor(pc_cond.colors[:, :3]).float().cuda() / 255.0
torch.tensor(pc_cond.colors[:, :3]).float().to(device) / 255.0
)
model_batch["pc_cond"] = torch.cat([xyz, color_rgb], dim=-1).unsqueeze(
0
Expand Down Expand Up @@ -195,7 +195,7 @@ def run_model(
model,
guidance_scale=guidance_scale,
seed=random_seed,
device="cuda",
device=device,
remesh_option=remesh_option.lower(),
vertex_count=vertex_count,
texture_resolution=texture_resolution,
Expand Down
7 changes: 6 additions & 1 deletion spar3d/models/diffusion/gaussian_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@
import numpy as np
import torch as th

from spar3d.utils import get_device


def sigmoid_schedule(t, start=-3, end=3, tau=0.6, clip_min=1e-9):
def sigmoid(x):
Expand Down Expand Up @@ -169,7 +171,10 @@ def space_timesteps(num_timesteps, section_counts):

def _extract_into_tensor(arr, timesteps, broadcast_shape):
"""Extract values from a 1-D numpy array for a batch of indices."""
res = th.from_numpy(arr).to(device=timesteps.device)[timesteps].float()
if get_device() == "mps":
res = th.from_numpy(arr.astype(np.float32)).to(device=timesteps.device)[timesteps]
else:
res = th.from_numpy(arr).to(device=timesteps.device)[timesteps].float()
while len(res.shape) < len(broadcast_shape):
res = res[..., None]
return res + th.zeros(broadcast_shape, device=timesteps.device)
Expand Down
9 changes: 7 additions & 2 deletions spar3d/models/image_estimator/clip_based_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

from spar3d.models.network import get_activation
from spar3d.models.utils import BaseModule
from spar3d.utils import get_device

OPENAI_DATASET_MEAN = (0.48145466, 0.4578275, 0.40821073)
OPENAI_DATASET_STD = (0.26862954, 0.26130258, 0.27577711)
Expand Down Expand Up @@ -120,8 +121,12 @@ def forward(
mean=OPENAI_DATASET_MEAN,
std=OPENAI_DATASET_STD,
)(cond_image)
mask = Normalize(0.5, 0.26)(mask).half()
image_features = self.model.visual(cond_image.half(), mask).float()
if get_device() != "mps":
mask = Normalize(0.5, 0.26)(mask).half()
image_features = self.model.visual(cond_image.half(), mask).float()
else:
mask = Normalize(0.5, 0.26)(mask)
image_features = self.model.visual(cond_image, mask)

# Run the heads
outputs = {}
Expand Down
17 changes: 11 additions & 6 deletions spar3d/system.py
Original file line number Diff line number Diff line change
Expand Up @@ -404,7 +404,8 @@ def _unload_pdiff_modules(self):
self.pdiff_backbone = None
self.diffusion_spaced = None
self.sampler = None
torch.cuda.empty_cache()
if get_device() == "cuda":
torch.cuda.empty_cache()

def _unload_main_modules(self):
"""Unload main processing modules to free memory"""
Expand All @@ -414,13 +415,15 @@ def _unload_main_modules(self):
self.camera_embedder = None
self.backbone = None
self.post_processor = None
torch.cuda.empty_cache()
if get_device() == "cuda":
torch.cuda.empty_cache()

def _unload_estimator_modules(self):
"""Unload estimator modules to free memory"""
self.image_estimator = None
self.global_estimator = None
torch.cuda.empty_cache()
if get_device() == "cuda":
torch.cuda.empty_cache()

def triplane_to_meshes(
self, triplanes: Float[Tensor, "B 3 Cp Hp Wp"]
Expand Down Expand Up @@ -663,17 +666,19 @@ def generate_mesh(
batch["mask_cond"], self.cfg.cond_image_size
)

device = get_device()

batch_size = batch["rgb_cond"].shape[0]

if pointcloud is not None:
if isinstance(pointcloud, list):
cond_tensor = torch.tensor(pointcloud).float().cuda().view(-1, 6)
cond_tensor = torch.tensor(pointcloud).float().to(device).view(-1, 6)
xyz = cond_tensor[:, :3]
color_rgb = cond_tensor[:, 3:]
# Check if point cloud is a numpy array
elif isinstance(pointcloud, np.ndarray):
xyz = torch.tensor(pointcloud[:, :3]).float().cuda()
color_rgb = torch.tensor(pointcloud[:, 3:]).float().cuda()
xyz = torch.tensor(pointcloud[:, :3]).float().to(device)
color_rgb = torch.tensor(pointcloud[:, 3:]).float().to(device)
else:
raise ValueError("Invalid point cloud type")

Expand Down
3 changes: 2 additions & 1 deletion texture_baker/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ def get_extensions():
"-O3" if not debug_mode else "-O0",
"-fdiagnostics-color=always",
"-fopenmp",
"-c++17",
]
+ ["-march=native"]
if use_native_arch
Expand Down Expand Up @@ -81,7 +82,7 @@ def get_extensions():
sources += glob.glob(
os.path.join(this_dir, library_name, "csrc", "**", "*.mm"), recursive=True
)
extra_compile_args.update({"cxx": ["-O3", "-arch", "arm64"]})
extra_compile_args.update({"cxx": ["-O3", "-arch", "arm64", "-mmacosx-version-min=10.15"]})
extra_link_args += ["-arch", "arm64"]

extensions.append(
Expand Down
3 changes: 2 additions & 1 deletion uv_unwrapper/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@ def get_extensions():
]
+ ["-march=native"]
if use_native_arch
else [],
else []
+ ["-fno-aligned-new"] if is_mac else [],
}
if debug_mode:
extra_compile_args["cxx"].append("-g")
Expand Down

0 comments on commit 58cec73

Please sign in to comment.