diff --git a/nerfstudio/models/base_model.py b/nerfstudio/models/base_model.py index e1c9507ff0..febd4ab105 100644 --- a/nerfstudio/models/base_model.py +++ b/nerfstudio/models/base_model.py @@ -214,11 +214,7 @@ def get_rgba_image(self, outputs: Dict[str, torch.Tensor], output_name: str = "r RGBA image. """ accumulation_name = output_name.replace("rgb", "accumulation") - if ( - not hasattr(self, "renderer_rgb") - or not hasattr(self.renderer_rgb, "background_color") - or accumulation_name not in outputs - ): + if accumulation_name not in outputs: raise NotImplementedError(f"get_rgba_image is not implemented for model {self.__class__.__name__}") rgb = outputs[output_name] if self.renderer_rgb.background_color == "random": # type: ignore diff --git a/nerfstudio/scripts/render.py b/nerfstudio/scripts/render.py index 70208a45d2..4eb4a71840 100644 --- a/nerfstudio/scripts/render.py +++ b/nerfstudio/scripts/render.py @@ -197,6 +197,9 @@ def _render_trajectory_video( outputs = pipeline.model.get_outputs_for_camera( cameras[camera_idx : camera_idx + 1], obb_box=obb_box ) + if rendered_output_names is not None and "rgba" in rendered_output_names: + rgba = pipeline.model.get_rgba_image(outputs=outputs, output_name="rgb") + outputs["rgba"] = rgba render_image = [] for rendered_output_name in rendered_output_names: @@ -221,6 +224,8 @@ def _render_trajectory_video( .cpu() .numpy() ) + elif rendered_output_name == "rgba": + output_image = output_image.detach().cpu().numpy() else: output_image = ( colormaps.apply_colormap( @@ -790,6 +795,9 @@ def update_config(config: TrainerConfig) -> TrainerConfig: for camera_idx, (camera, batch) in enumerate(progress.track(dataloader, total=len(dataset))): with torch.no_grad(): outputs = pipeline.model.get_outputs_for_camera(camera) + if self.rendered_output_names is not None and "rgba" in self.rendered_output_names: + rgba = pipeline.model.get_rgba_image(outputs=outputs, output_name="rgb") + outputs["rgba"] = rgba gt_batch = batch.copy() gt_batch["rgb"] = gt_batch.pop("image") @@ -841,11 +849,12 @@ def update_config(config: TrainerConfig) -> TrainerConfig: output_image = gt_batch[output_name] else: output_image = outputs[output_name] - del output_name # Map to color spaces / numpy if is_raw: output_image = output_image.cpu().numpy() + elif output_name == "rgba": + output_image = output_image.detach().cpu().numpy() elif is_depth: output_image = ( colormaps.apply_depth_colormap(