Skip to content

Commit

Permalink
PixelSampler Fixes (#2298)
Browse files Browse the repository at this point in the history
* monodepth implemented

* device

* pair pixel sampler

* docstring

* removing unused imports

* fix bugs

* cache in cache dir, no longer edit json file

* moving to depth nerfacto

* fixed bug using median depth by accident

* removing depthdataset from nerfacto

* added pixelsamplerconfig

* added equirectangular warning back

* bug fix

* delete sample_config.py + small fixes

* rename to pixel_sampler

* cleaning

* test

* passing num_rays_per_batch properly

* fixed transforms typo

* renaming in nerfacto

* remove depth files from test data

* removed defaulting to base pixel sampler with mask

* adding patch_size back

* Warning for notimplemented error and patch_size fix

* fix config.pixel_sampler check

---------

Co-authored-by: Matthew Tancik <[email protected]>
Co-authored-by: Ethan Weber <[email protected]>
Co-authored-by: Ethan Weber <[email protected]>
  • Loading branch information
4 people authored Aug 4, 2023
1 parent 9b03299 commit 7e90cb7
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 7 deletions.
2 changes: 1 addition & 1 deletion nerfstudio/data/datamanagers/base_datamanager.py
Original file line number Diff line number Diff line change
Expand Up @@ -461,7 +461,7 @@ def create_eval_dataset(self) -> TDataset:

def _get_pixel_sampler(self, dataset: TDataset, num_rays_per_batch: int) -> PixelSampler:
"""Infer pixel sampler to use."""
if self.config.patch_size > 1:
if self.config.patch_size > 1 and type(self.config.pixel_sampler) is PixelSamplerConfig:
return PatchPixelSamplerConfig().setup(
patch_size=self.config.patch_size, num_rays_per_batch=num_rays_per_batch
)
Expand Down
13 changes: 7 additions & 6 deletions nerfstudio/data/pixel_samplers.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,8 +318,9 @@ def sample_method(
device: Union[torch.device, str] = "cpu",
) -> Int[Tensor, "batch_size 3"]:
if isinstance(mask, Tensor):
# Note: if there is a mask, should switch to the base PixelSampler class
raise NotImplementedError()
raise NotImplementedError(
"Masked sampling not implemented for PairPixelSampler. Change Config to PixelSamplerConfig instead."
)
else:
sub_bs = batch_size // (self.config.patch_size**2)
indices = torch.rand((sub_bs, 3), device=device) * torch.tensor(
Expand Down Expand Up @@ -379,10 +380,10 @@ def sample_method( # pylint: disable=no-self-use
mask: Optional[Tensor] = None,
device: Union[torch.device, str] = "cpu",
) -> Int[Tensor, "batch_size 3"]:
if mask:
# Note: if there is a mask, should switch to the base PixelSampler class

raise NotImplementedError()
if isinstance(mask, Tensor):
raise NotImplementedError(
"Masked sampling not implemented for PairPixelSampler. Change Config to PixelSamplerConfig instead."
)
else:
s = (self.rays_to_sample, 1)
ns = torch.randint(0, num_images, s, dtype=torch.long, device=device)
Expand Down

0 comments on commit 7e90cb7

Please sign in to comment.