diff --git a/crates/re_types/src/datatypes/tensor_data_ext.rs b/crates/re_types/src/datatypes/tensor_data_ext.rs
index 479838ac1075..0dea58904bb3 100644
--- a/crates/re_types/src/datatypes/tensor_data_ext.rs
+++ b/crates/re_types/src/datatypes/tensor_data_ext.rs
@@ -21,7 +21,7 @@ impl TensorData {
         self.shape.as_slice()
     }
 
-    /// Returns the shape of the tensor with all trailing dimensions of size 1 ignored.
+    /// Returns the shape of the tensor with all leading & trailing dimensions of size 1 ignored.
     ///
     /// If all dimension sizes are one, this returns only the first dimension.
     #[inline]
@@ -29,12 +29,9 @@ impl TensorData {
         if self.shape.is_empty() {
             &self.shape
         } else {
-            self.shape
-                .iter()
-                .enumerate()
-                .rev()
-                .find(|(_, dim)| dim.size != 1)
-                .map_or(&self.shape[0..1], |(i, _)| &self.shape[..(i + 1)])
+            let first_not_one = self.shape.iter().position(|dim| dim.size != 1);
+            let last_not_one = self.shape.iter().rev().position(|dim| dim.size != 1);
+            &self.shape[first_not_one.unwrap_or(0)..self.shape.len() - last_not_one.unwrap_or(0)]
         }
     }
 
diff --git a/examples/python/depth_guided_stable_diffusion/README.md b/examples/python/depth_guided_stable_diffusion/README.md
index eeae6dfaf07f..867c02c0e774 100644
--- a/examples/python/depth_guided_stable_diffusion/README.md
+++ b/examples/python/depth_guided_stable_diffusion/README.md
@@ -6,13 +6,12 @@
 thumbnail_dimensions = [480, 253]
 channel = "nightly"
 -->
 
-Depth-guided stable diffusion screenshot (old <picture> image markup, stripped from this excerpt)
+Depth-guided stable diffusion screenshot (new <picture> image markup, stripped from this excerpt)
 
 A more elaborate example running Depth Guided Stable Diffusion 2.0.
 
diff --git a/examples/python/depth_guided_stable_diffusion/huggingface_pipeline.py b/examples/python/depth_guided_stable_diffusion/huggingface_pipeline.py
index dd1d71ad5b46..eab92e43bd89 100644
--- a/examples/python/depth_guided_stable_diffusion/huggingface_pipeline.py
+++ b/examples/python/depth_guided_stable_diffusion/huggingface_pipeline.py
@@ -218,7 +218,7 @@ def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_fr
             return_tensors="pt",
         )
         text_input_ids = text_inputs.input_ids
-        rr.log("prompt/text_input/ids", rr.Tensor(text_input_ids))
+        rr.log("prompt/text_input/ids", rr.BarChart(text_input_ids))
         untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
 
         if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
@@ -229,7 +229,7 @@ def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_fr
             )
 
         if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
-            rr.log("prompt/text_input/attention_mask", rr.Tensor(text_inputs.attention_mask))
+            rr.log("prompt/text_input/attention_mask", rr.BarChart(text_inputs.attention_mask))
             attention_mask = text_inputs.attention_mask.to(device)
         else:
             attention_mask = None
diff --git a/examples/python/depth_guided_stable_diffusion/main.py b/examples/python/depth_guided_stable_diffusion/main.py
index 214fad722d40..31712eb4dccf 100755
--- a/examples/python/depth_guided_stable_diffusion/main.py
+++ b/examples/python/depth_guided_stable_diffusion/main.py
@@ -17,6 +17,7 @@
 
 import requests
 import rerun as rr  # pip install rerun-sdk
+import rerun.blueprint as rrb
 import torch
 from huggingface_pipeline import StableDiffusionDepth2ImgPipeline
 from PIL import Image
@@ -112,7 +113,57 @@ def main() -> None:
     rr.script_add_args(parser)
     args = parser.parse_args()
 
-    rr.script_setup(args, "rerun_example_depth_guided_stable_diffusion")
+    rr.script_setup(
+        args,
+        "rerun_example_depth_guided_stable_diffusion",
+        # This example is very complex, making it too hard for the Viewer to infer a good layout.
+        # Therefore, we specify everything explicitly:
+        # We set up three columns using a `Horizontal` layout, one each for
+        # * inputs
+        # * depth & initializations
+        # * diffusion outputs
+        blueprint=rrb.Blueprint(
+            rrb.Horizontal(
+                rrb.Vertical(
+                    rrb.Tabs(
+                        rrb.Spatial2DView(name="Image original", origin="image/original"),
+                        rrb.TensorView(name="Image preprocessed", origin="input_image/preprocessed"),
+                    ),
+                    rrb.Vertical(
+                        rrb.TextLogView(name="Prompt", contents=["prompt/text", "prompt/text_negative"]),
+                        rrb.Tabs(
+                            rrb.TensorView(name="Text embeddings", origin="prompt/text_embeddings"),
+                            rrb.TensorView(name="Unconditional embeddings", origin="prompt/uncond_embeddings"),
+                        ),
+                        rrb.BarChartView(name="Prompt ids", origin="prompt/text_input"),
+                    ),
+                ),
+                rrb.Vertical(
+                    rrb.Tabs(
+                        rrb.Spatial2DView(name="Depth estimated", origin="depth/estimated"),
+                        rrb.Spatial2DView(name="Depth interpolated", origin="depth/interpolated"),
+                        rrb.Spatial2DView(name="Depth normalized", origin="depth/normalized"),
+                        rrb.TensorView(name="Depth input pre-processed", origin="depth/input_preprocessed"),
+                        active_tab="Depth interpolated",
+                    ),
+                    rrb.Tabs(
+                        rrb.TensorView(name="Encoded input", origin="encoded_input_image"),
+                        rrb.TensorView(name="Decoded init latents", origin="decoded_init_latents"),
+                    ),
+                ),
+                rrb.Vertical(
+                    rrb.Spatial2DView(name="Image diffused", origin="image/diffused"),
+                    rrb.Horizontal(
+                        rrb.TensorView(name="Latent Model Input", origin="diffusion/latent_model_input"),
+                        rrb.TensorView(name="Diffusion latents", origin="diffusion/latents"),
+                        # rrb.TensorView(name="Noise Prediction", origin="diffusion/noise_pred"),
+                    ),
+                ),
+            ),
+            rrb.SelectionPanel(expanded=False),
+            rrb.TimePanel(expanded=False),
+        ),
+    )
 
     image_path = args.image_path  # type: str
     if not image_path:
diff --git a/rerun_py/rerun_sdk/rerun/archetypes/bar_chart_ext.py b/rerun_py/rerun_sdk/rerun/archetypes/bar_chart_ext.py
index 297a99cb7999..4e329817545c 100644
--- a/rerun_py/rerun_sdk/rerun/archetypes/bar_chart_ext.py
+++ b/rerun_py/rerun_sdk/rerun/archetypes/bar_chart_ext.py
@@ -23,7 +23,7 @@ def values__field_converter_override(data: TensorDataArrayLike) -> TensorDataBat
         # once we coerce to a canonical non-arrow type.
         shape_dims = tensor_data.as_arrow_array()[0].value["shape"].values.field(0).to_numpy()
 
-        if len(shape_dims) != 1:
+        if len([d for d in shape_dims if d != 1]) != 1:
             _send_warning_or_raise(
                 f"Bar chart data should only be 1D. Got values with shape: {shape_dims}",
                 2,