Skip to content

Commit

Permalink
Fix bug in pixel sampler
Browse files Browse the repository at this point in the history
  • Loading branch information
carlinds committed Apr 24, 2024
1 parent d9b022b commit d9391f9
Showing 1 changed file with 15 additions and 13 deletions.
28 changes: 15 additions & 13 deletions nerfstudio/data/pixel_samplers.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
Code for sampling pixels.
"""

import math
import random
import warnings
from dataclasses import dataclass, field
Expand Down Expand Up @@ -302,19 +303,22 @@ def collate_image_dataset_batch_list(self, batch: Dict, num_rays_per_batch: int,
all_images = []
all_depth_images = []

num_rays_in_batch = num_rays_per_batch // num_images
if num_rays_in_batch % 2 != 0:
num_rays_in_batch += 1
# find the optimal number of rays per image such that it is divisible by 2 and sums to the total number of rays
num_rays_per_image = num_rays_per_batch / num_images
residual = num_rays_per_image % 2
num_rays_per_image_under = int(num_rays_per_image - residual)
num_rays_per_image_over = int(num_rays_per_image_under + 2)
num_images_under = math.ceil(num_images * (1 - residual / 2))
num_images_over = num_images - num_images_under
num_rays_per_image = num_images_under * [num_rays_per_image_under] + num_images_over * [num_rays_per_image_over]
num_rays_per_image[-1] += num_rays_per_batch - sum(num_rays_per_image)

if "mask" in batch:
for i in range(num_images):
for i, num_rays in enumerate(num_rays_per_image):
image_height, image_width, _ = batch["image"][i].shape

if i == num_images - 1:
num_rays_in_batch = num_rays_per_batch - (num_images - 1) * num_rays_in_batch

indices = self.sample_method(
num_rays_in_batch, 1, image_height, image_width, mask=batch["mask"][i].unsqueeze(0), device=device
num_rays, 1, image_height, image_width, mask=batch["mask"][i].unsqueeze(0), device=device
)
indices[:, 0] = i
all_indices.append(indices)
Expand All @@ -323,16 +327,14 @@ def collate_image_dataset_batch_list(self, batch: Dict, num_rays_per_batch: int,
all_depth_images.append(batch["depth_image"][i][indices[:, 1], indices[:, 2]])

else:
for i in range(num_images):
for i, num_rays in enumerate(num_rays_per_image):
image_height, image_width, _ = batch["image"][i].shape
if i == num_images - 1:
num_rays_in_batch = num_rays_per_batch - (num_images - 1) * num_rays_in_batch
if self.config.is_equirectangular:
indices = self.sample_method_equirectangular(
num_rays_in_batch, 1, image_height, image_width, device=device
num_rays, 1, image_height, image_width, device=device
)
else:
indices = self.sample_method(num_rays_in_batch, 1, image_height, image_width, device=device)
indices = self.sample_method(num_rays, 1, image_height, image_width, device=device)
indices[:, 0] = i
all_indices.append(indices)
all_images.append(batch["image"][i][indices[:, 1], indices[:, 2]])
Expand Down

0 comments on commit d9391f9

Please sign in to comment.