diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml index de956c243e..d9bad9839c 100644 --- a/.github/workflows/publish.yml +++ b/.github/workflows/publish.yml @@ -6,6 +6,25 @@ on: - "v*" jobs: + publish-burn-vision: + uses: tracel-ai/github-actions/.github/workflows/publish-crate.yml@v1 + with: + crate: burn-vision + needs: + - publish-burn-autodiff + - publish-burn-candle + - publish-burn-fusion + - publish-burn-jit + - publish-burn-ndarray + - publish-burn-tch + - publish-burn-tensor + - publish-burn-tensor-testgen + # dev dependencies + - publish-burn-wgpu + - publish-burn-cuda + secrets: + CRATES_IO_API_TOKEN: ${{ secrets.CRATES_IO_API_TOKEN }} + publish-burn-router: uses: tracel-ai/github-actions/.github/workflows/publish-crate.yml@v1 with: diff --git a/Cargo.lock b/Cargo.lock index 3fa040623f..6483e18960 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -900,6 +900,25 @@ dependencies = [ "tracing-subscriber", ] +[[package]] +name = "burn-vision" +version = "0.17.0" +dependencies = [ + "burn-candle", + "burn-cuda", + "burn-fusion", + "burn-jit", + "burn-ndarray", + "burn-tch", + "burn-tensor", + "burn-tensor-testgen", + "burn-wgpu", + "cubecl", + "derive-new 0.7.0", + "ndarray 0.16.1", + "serde", +] + [[package]] name = "burn-wgpu" version = "0.17.0" diff --git a/crates/burn-candle/src/element.rs b/crates/burn-candle/src/element.rs index df5c2dc756..ebe4a056c2 100644 --- a/crates/burn-candle/src/element.rs +++ b/crates/burn-candle/src/element.rs @@ -4,8 +4,11 @@ use burn_tensor::Element; use candle_core::{FloatDType, Tensor, WithDType}; use half::{bf16, f16}; +/// Candle element pub trait CandleElement: Element + WithDType {} +/// Candle float element pub trait FloatCandleElement: CandleElement + FloatDType {} +/// Candle int element pub trait IntCandleElement: CandleElement {} impl CandleElement for f64 {} diff --git a/crates/burn-candle/src/lib.rs b/crates/burn-candle/src/lib.rs index 78923fad03..64a6d05330 100644 --- a/crates/burn-candle/src/lib.rs +++ b/crates/burn-candle/src/lib.rs @@ -13,6 +13,7 @@ mod ops; mod tensor; pub use backend::*; +pub use element::*; pub use tensor::*; #[cfg(test)] diff --git a/crates/burn-cuda/Cargo.toml b/crates/burn-cuda/Cargo.toml index 1a92e695b2..290e79b804 100644 --- a/crates/burn-cuda/Cargo.toml +++ b/crates/burn-cuda/Cargo.toml @@ -12,8 +12,8 @@ repository = "https://github.com/tracel-ai/burn/tree/main/crates/burn-cuda" version.workspace = true [features] -default = ["fusion", "autotune", "burn-jit/default", "cubecl/default"] autotune = ["burn-jit/autotune"] +default = ["fusion", "autotune", "burn-jit/default", "cubecl/default"] doc = ["burn-jit/doc"] fusion = ["burn-fusion", "burn-jit/fusion"] std = ["burn-jit/std", "cubecl/std"] diff --git a/crates/burn-jit/src/kernel/index/mod.rs b/crates/burn-jit/src/kernel/index/mod.rs index 828c39c50c..83ce64aff8 100644 --- a/crates/burn-jit/src/kernel/index/mod.rs +++ b/crates/burn-jit/src/kernel/index/mod.rs @@ -11,7 +11,7 @@ pub(crate) use flip::*; pub(crate) use repeat_dim::*; pub(crate) use select::*; pub(crate) use select_assign::*; -pub(crate) use slice::*; +pub use slice::*; pub(crate) use slice_assign::*; pub(crate) use gather::*; diff --git a/crates/burn-jit/src/kernel/index/slice.rs b/crates/burn-jit/src/kernel/index/slice.rs index b6daba8da5..bca8e00dd9 100644 --- a/crates/burn-jit/src/kernel/index/slice.rs +++ b/crates/burn-jit/src/kernel/index/slice.rs @@ -3,7 +3,8 @@ use burn_tensor::Shape; use cubecl::{calculate_cube_count_elemwise, prelude::*}; use std::ops::Range; -pub(crate) fn slice( +/// Slice a jit tensor with a set of ranges +pub fn slice( tensor: JitTensor, indices: &[Range], ) -> JitTensor { diff --git a/crates/burn-jit/src/kernel/mod.rs b/crates/burn-jit/src/kernel/mod.rs index 93d2833976..e1a0a3158e 100644 --- a/crates/burn-jit/src/kernel/mod.rs +++ b/crates/burn-jit/src/kernel/mod.rs @@ -39,4 +39,4 @@ pub mod reduce; pub(crate) use clamp::*; pub(crate) use comparison::*; -pub(crate) use index::*; +pub use index::*; diff --git a/crates/burn-jit/src/lib.rs b/crates/burn-jit/src/lib.rs index acf69d9aec..ae15fb945f 100644 --- a/crates/burn-jit/src/lib.rs +++ b/crates/burn-jit/src/lib.rs @@ -7,7 +7,8 @@ extern crate derive_new; extern crate alloc; -mod ops; +/// Utilities for implementing JIT kernels +pub mod ops; /// Kernel module pub mod kernel; diff --git a/crates/burn-jit/src/ops/base.rs b/crates/burn-jit/src/ops/base.rs index 645aaf1535..112a11de33 100644 --- a/crates/burn-jit/src/ops/base.rs +++ b/crates/burn-jit/src/ops/base.rs @@ -76,6 +76,7 @@ pub(crate) fn swap_dims( tensor } +/// Permute a tensor's dimensions pub fn permute(mut tensor: JitTensor, axes: &[usize]) -> JitTensor { // remap strides tensor.strides = axes.iter().map(|i| tensor.strides[*i]).collect(); @@ -135,7 +136,8 @@ pub(crate) fn expand(tensor: JitTensor, target_shape: Shape) - } } -pub(crate) fn reshape(tensor: JitTensor, shape: Shape) -> JitTensor { +/// Reshape a jit tensor to a new shape +pub fn reshape(tensor: JitTensor, shape: Shape) -> JitTensor { // TODO: Not force standard layout all the time (improve performance). let tensor = kernel::into_contiguous(tensor); diff --git a/crates/burn-jit/src/ops/mod.rs b/crates/burn-jit/src/ops/mod.rs index 2e23e3835d..c396bdacdd 100644 --- a/crates/burn-jit/src/ops/mod.rs +++ b/crates/burn-jit/src/ops/mod.rs @@ -7,6 +7,7 @@ mod qtensor; mod transaction; pub(crate) mod base; -pub(crate) use base::*; +pub use base::*; -pub(crate) mod numeric; +/// Numeric utility functions for jit backends +pub mod numeric; diff --git a/crates/burn-jit/src/ops/numeric.rs b/crates/burn-jit/src/ops/numeric.rs index cf15916aab..432276ccb6 100644 --- a/crates/burn-jit/src/ops/numeric.rs +++ b/crates/burn-jit/src/ops/numeric.rs @@ -9,6 +9,7 @@ use cubecl::client::ComputeClient; use cubecl::tensor_vectorization_factor; use cubecl::{calculate_cube_count_elemwise, prelude::*}; +/// Create a tensor filled with `value` pub fn full( shape: Shape, device: &R::Device, @@ -19,6 +20,7 @@ pub fn full( full_device::(client, shape, device.clone(), value) } +/// Create a tensor filled with `value` pub fn full_device( client: ComputeClient, shape: Shape, @@ -56,12 +58,14 @@ pub fn full_device( empty } +/// Create a tensor filled with zeros pub fn zeros(shape: Shape, device: &R::Device) -> JitTensor { let client = R::client(device); zeros_device::(client, device.clone(), shape) } +/// Create a tensor filled with zeros pub fn zeros_device( client: ComputeClient, device: R::Device, @@ -70,12 +74,14 @@ pub fn zeros_device( full_device::(client, shape, device, 0.elem()) } +/// Create a tensor filled with ones pub fn ones(shape: Shape, device: &R::Device) -> JitTensor { let client = R::client(device); ones_device::(client, device.clone(), shape) } +/// Create a tensor filled with ones pub fn ones_device( client: ComputeClient, device: R::Device, @@ -84,6 +90,7 @@ pub fn ones_device( full_device::(client, shape, device, 1.elem()) } +/// Create a tensor with uninitialized memory pub fn empty_device( client: ComputeClient, device: R::Device, @@ -94,38 +101,47 @@ pub fn empty_device( JitTensor::new_contiguous(client, device, shape, buffer, E::dtype()) } +/// Add two tensors pub fn add(lhs: JitTensor, rhs: JitTensor) -> JitTensor { launch_binop::(lhs, rhs) } +/// Add a tensor and a scalar pub fn add_scalar(lhs: JitTensor, rhs: E) -> JitTensor { launch_scalar_binop::(lhs, rhs) } +/// Subtract two tensors pub fn sub(lhs: JitTensor, rhs: JitTensor) -> JitTensor { launch_binop::(lhs, rhs) } +/// Subtract a tensor and a scalar pub fn sub_scalar(lhs: JitTensor, rhs: E) -> JitTensor { launch_scalar_binop::(lhs, rhs) } +/// Multiply two tensors pub fn mul(lhs: JitTensor, rhs: JitTensor) -> JitTensor { launch_binop::(lhs, rhs) } +/// Multiply a tensor and a scalar pub fn mul_scalar(lhs: JitTensor, rhs: E) -> JitTensor { launch_scalar_binop::(lhs, rhs) } +/// Divide two tensors pub fn div(lhs: JitTensor, rhs: JitTensor) -> JitTensor { launch_binop::(lhs, rhs) } +/// Divide a tensor by a scalar pub fn div_scalar(lhs: JitTensor, rhs: E) -> JitTensor { launch_scalar_binop::(lhs, rhs) } +/// Calculate remainder of two tensors pub fn remainder( lhs: JitTensor, rhs: JitTensor, @@ -133,14 +149,17 @@ pub fn remainder( launch_binop::(lhs, rhs) } +/// Calculate the remainder of a tensor with a scalar pub fn remainder_scalar(lhs: JitTensor, rhs: E) -> JitTensor { launch_scalar_binop::(lhs, rhs) } +/// Calculate the power of two tensors pub fn pow(lhs: JitTensor, rhs: JitTensor) -> JitTensor { launch_binop::>(lhs, rhs) } +/// Bitwise and two tensors pub fn bitwise_and( lhs: JitTensor, rhs: JitTensor, @@ -148,10 +167,12 @@ pub fn bitwise_and( launch_binop_int::(lhs, rhs) } +/// Bitwise and with a scalar pub fn bitwise_and_scalar(lhs: JitTensor, rhs: E) -> JitTensor { launch_scalar_binop_int::(lhs, rhs) } +/// Bitwise or two tensors pub fn bitwise_or( lhs: JitTensor, rhs: JitTensor, @@ -159,10 +180,12 @@ pub fn bitwise_or( launch_binop_int::(lhs, rhs) } +/// Bitwise or with a scalar pub fn bitwise_or_scalar(lhs: JitTensor, rhs: E) -> JitTensor { launch_scalar_binop_int::(lhs, rhs) } +/// Bitwise xor two tensors pub fn bitwise_xor( lhs: JitTensor, rhs: JitTensor, @@ -170,6 +193,7 @@ pub fn bitwise_xor( launch_binop_int::(lhs, rhs) } +/// Bitwise xor with a scalar pub fn bitwise_xor_scalar(lhs: JitTensor, rhs: E) -> JitTensor { launch_scalar_binop_int::(lhs, rhs) } diff --git a/crates/burn-jit/src/tensor/base.rs b/crates/burn-jit/src/tensor/base.rs index b586c4a6b7..7b72073c06 100644 --- a/crates/burn-jit/src/tensor/base.rs +++ b/crates/burn-jit/src/tensor/base.rs @@ -23,7 +23,8 @@ pub struct JitTensor { pub device: R::Device, /// The strides of the tensor. pub strides: Vec, - pub(crate) dtype: DType, + /// The datatype of the tensor. + pub dtype: DType, } impl From> for TensorHandle { diff --git a/crates/burn-ndarray/src/element.rs b/crates/burn-ndarray/src/element.rs index a700d9e30f..093ecf5e8f 100644 --- a/crates/burn-ndarray/src/element.rs +++ b/crates/burn-ndarray/src/element.rs @@ -16,6 +16,7 @@ where { } +/// An int element for ndarray backend. pub trait IntNdArrayElement: NdArrayElement + Signed {} /// A general element for ndarray backend. @@ -34,13 +35,21 @@ pub trait NdArrayElement: /// A element for ndarray backend that supports exp ops. pub trait ExpElement { + /// Exponent fn exp_elem(self) -> Self; + /// Log fn log_elem(self) -> Self; + /// Log1p fn log1p_elem(self) -> Self; + /// Powf fn powf_elem(self, value: f32) -> Self; + /// Powi fn powi_elem(self, value: i32) -> Self; + /// Sqrt fn sqrt_elem(self) -> Self; + /// Abs fn abs_elem(self) -> Self; + /// Abs for int fn int_abs_elem(self) -> Self; } diff --git a/crates/burn-ndarray/src/lib.rs b/crates/burn-ndarray/src/lib.rs index 60c139bd25..95736b5efe 100644 --- a/crates/burn-ndarray/src/lib.rs +++ b/crates/burn-ndarray/src/lib.rs @@ -21,7 +21,7 @@ mod sharing; mod tensor; pub use backend::*; -pub use element::FloatNdArrayElement; +pub use element::*; pub(crate) use sharing::*; pub use tensor::*; diff --git a/crates/burn-ndarray/src/ops/conv.rs b/crates/burn-ndarray/src/ops/conv.rs index 429618826a..8f45b8f8f8 100644 --- a/crates/burn-ndarray/src/ops/conv.rs +++ b/crates/burn-ndarray/src/ops/conv.rs @@ -11,7 +11,7 @@ use ndarray::{ }; use crate::{ - element::{FloatNdArrayElement, IntNdArrayElement, QuantElement}, + element::FloatNdArrayElement, ops::padding::{apply_padding_4d, apply_padding_5d}, sharing::UnsafeSharedRef, tensor::NdArrayTensor, @@ -98,7 +98,7 @@ fn conv3d_mad_inner( } } -pub(crate) fn conv2d( +pub(crate) fn conv2d( x: NdArrayTensor, weight: NdArrayTensor, bias: Option>, @@ -126,7 +126,7 @@ pub(crate) fn conv2d(x, options.padding, 0i32.elem()).array; + let x = apply_padding_4d::(x, options.padding, 0i32.elem()).array; // Convert inputs from dynamic indexes to static to improve perf. let x = x.into_dimensionality::().unwrap(); @@ -310,7 +310,7 @@ pub(crate) fn conv_transpose2d( NdArrayTensor::new(output.into_dyn().into_shared()) } -pub(crate) fn conv3d( +pub(crate) fn conv3d( x: NdArrayTensor, weight: NdArrayTensor, bias: Option>, @@ -345,7 +345,7 @@ pub(crate) fn conv3d(x, options.padding, 0i32.elem()).array; + let x = apply_padding_5d::(x, options.padding, 0i32.elem()).array; // Convert inputs from dynamic indexes to static to improve perf. let x = x.into_dimensionality::().unwrap(); diff --git a/crates/burn-ndarray/src/ops/deform_conv.rs b/crates/burn-ndarray/src/ops/deform_conv.rs index 56b969a67c..a003e392f3 100644 --- a/crates/burn-ndarray/src/ops/deform_conv.rs +++ b/crates/burn-ndarray/src/ops/deform_conv.rs @@ -11,7 +11,7 @@ use ndarray::{ #[cfg(not(feature = "std"))] use num_traits::Float; -use crate::{element::QuantElement, FloatNdArrayElement, NdArrayTensor}; +use crate::{FloatNdArrayElement, NdArrayTensor}; use super::matmul::matmul; @@ -255,7 +255,6 @@ pub mod backward { #[cfg(target_has_atomic = "32")] use core::sync::atomic::Ordering; - use crate::element::IntNdArrayElement; use atomic_float::AtomicF32; use ndarray::{Array1, Array5, ArrayView4, ArrayView6, Ix4}; @@ -270,11 +269,7 @@ pub mod backward { ); /// Calculate the [deformable 2D convolution](crate::ops::ModuleOps::deform_conv2d) backward pass using convolutions. - pub(crate) fn deform_conv2d_backward< - F: FloatNdArrayElement, - I: IntNdArrayElement, - Q: QuantElement, - >( + pub(crate) fn deform_conv2d_backward( input: NdArrayTensor, offset: NdArrayTensor, weight: NdArrayTensor, diff --git a/crates/burn-ndarray/src/ops/maxpool.rs b/crates/burn-ndarray/src/ops/maxpool.rs index 90ffe30a95..b7f8e776e3 100644 --- a/crates/burn-ndarray/src/ops/maxpool.rs +++ b/crates/burn-ndarray/src/ops/maxpool.rs @@ -1,5 +1,5 @@ use crate::{ - element::{FloatNdArrayElement, IntNdArrayElement, QuantElement}, + element::{FloatNdArrayElement, IntNdArrayElement}, ops::padding::apply_padding_4d, sharing::UnsafeSharedRef, tensor::NdArrayTensor, @@ -9,7 +9,7 @@ use burn_common::{iter_range_par, run_par}; use burn_tensor::{ElementConversion, TensorMetadata}; use ndarray::Array4; -pub(crate) fn max_pool2d( +pub(crate) fn max_pool2d( x: NdArrayTensor, kernel_size: [usize; 2], stride: [usize; 2], @@ -30,7 +30,7 @@ pub(crate) fn max_pool2d(x, padding, inf).array; + let x = apply_padding_4d::(x, padding, inf).array; let mut output = Array4::from_elem((batch_size, channels, out_height, out_width), inf); let unsafe_shared_out = UnsafeSharedRef::new(&mut output); @@ -69,11 +69,7 @@ pub(crate) fn max_pool2d( +pub(crate) fn max_pool2d_with_indices( x: NdArrayTensor, kernel_size: [usize; 2], stride: [usize; 2], @@ -94,7 +90,7 @@ pub(crate) fn max_pool2d_with_indices< / stride_width) + 1; - let x = apply_padding_4d::(x, padding, inf).array; + let x = apply_padding_4d::(x, padding, inf).array; let mut output = Array4::from_elem((batch_size, channels, out_height, out_width), inf); let mut indices = Array4::::zeros((batch_size, channels, out_height, out_width)); diff --git a/crates/burn-ndarray/src/ops/module.rs b/crates/burn-ndarray/src/ops/module.rs index f0885e52e2..dbceac1934 100644 --- a/crates/burn-ndarray/src/ops/module.rs +++ b/crates/burn-ndarray/src/ops/module.rs @@ -46,11 +46,7 @@ impl ModuleOps, options: ConvOptions<2>, ) -> NdArrayTensorFloat { - module_op!(inp(x, weight), opt(bias), E, |x, weight, bias| conv2d::< - E, - I, - Q, - >( + module_op!(inp(x, weight), opt(bias), E, |x, weight, bias| conv2d::( x, weight, bias, options ) .into()) @@ -89,7 +85,7 @@ impl ModuleOps( + let (x, offset, weight, mask, bias) = deform_conv2d_backward::( x, offset, weight, @@ -163,7 +159,7 @@ impl ModuleOps FloatTensor { - module_op!(inp(x), opt(), E, |x| max_pool2d::( + module_op!(inp(x), opt(), E, |x| max_pool2d::( x, kernel_size, stride, @@ -182,7 +178,7 @@ impl ModuleOps MaxPool2dWithIndices> { module_op!(inp(x), opt(), E, |x| { let (output, indices) = - max_pool2d_with_indices::(x, kernel_size, stride, padding, dilation); + max_pool2d_with_indices::(x, kernel_size, stride, padding, dilation); MaxPool2dWithIndices::new(output.into(), indices) }) } @@ -282,11 +278,7 @@ impl ModuleOps>, options: ConvOptions<3>, ) -> FloatTensor { - module_op!(inp(x, weight), opt(bias), E, |x, weight, bias| conv3d::< - E, - I, - Q, - >( + module_op!(inp(x, weight), opt(bias), E, |x, weight, bias| conv3d::( x, weight, bias, options ) .into()) diff --git a/crates/burn-ndarray/src/ops/padding.rs b/crates/burn-ndarray/src/ops/padding.rs index 99bcef5a3e..ccf3252205 100644 --- a/crates/burn-ndarray/src/ops/padding.rs +++ b/crates/burn-ndarray/src/ops/padding.rs @@ -1,13 +1,10 @@ -use crate::{ - element::{FloatNdArrayElement, IntNdArrayElement, QuantElement}, - tensor::NdArrayTensor, -}; +use crate::{element::FloatNdArrayElement, tensor::NdArrayTensor}; use burn_tensor::TensorMetadata; use ndarray::{Array4, Array5}; use super::NdArrayOps; -pub(crate) fn apply_padding_4d( +pub(crate) fn apply_padding_4d( x: NdArrayTensor, padding: [usize; 2], elem: E, @@ -37,7 +34,7 @@ pub(crate) fn apply_padding_4d( +pub(crate) fn apply_padding_5d( x: NdArrayTensor, padding: [usize; 3], elem: E, diff --git a/crates/burn-vision/Cargo.toml b/crates/burn-vision/Cargo.toml new file mode 100644 index 0000000000..59ded6d1ea --- /dev/null +++ b/crates/burn-vision/Cargo.toml @@ -0,0 +1,50 @@ +[package] +authors = [ + "nathanielsimard ", + "wingertge ", +] +categories = ["science"] +description = "Vision processing operations for burn tensors" +documentation = "https://docs.rs/burn-vision" +edition.workspace = true +keywords = ["deep-learning", "machine-learning", "gpu"] +license.workspace = true +name = "burn-vision" +readme.workspace = true +repository = "https://github.com/tracel-ai/burn/tree/main/crates/burn-vision" +version.workspace = true + + +[features] +candle = ["burn-candle"] +default = ["ndarray", "jit-backend", "fusion"] +export-tests = ["burn-tensor-testgen"] +fusion = ["burn-fusion", "burn-cuda/fusion", "burn-wgpu/fusion"] +jit-backend = ["cubecl", "burn-jit"] +ndarray = ["burn-ndarray"] +tch = ["burn-tch"] + +# Test features +test-cpu = ["export-tests"] +test-cuda = ["jit-backend", "export-tests"] +test-vulkan = ["burn-wgpu/vulkan", "jit-backend", "export-tests"] +test-wgpu = ["jit-backend", "export-tests"] + +[dependencies] +burn-candle = { path = "../burn-candle", version = "0.17.0", optional = true } +burn-fusion = { path = "../burn-fusion", version = "0.17.0", optional = true } +burn-jit = { path = "../burn-jit", version = "0.17.0", optional = true } +burn-ndarray = { path = "../burn-ndarray", version = "0.17.0", optional = true } +burn-tch = { path = "../burn-tch", version = "0.17.0", optional = true } +burn-tensor = { path = "../burn-tensor", version = "0.17.0" } +burn-tensor-testgen = { path = "../burn-tensor-testgen", version = "0.17.0", optional = true } +cubecl = { workspace = true, optional = true } +derive-new = { workspace = true } +ndarray = { workspace = true } +serde = { workspace = true } + +[dev-dependencies] +burn-cuda = { path = "../burn-cuda", version = "0.17.0", default-features = false } +burn-ndarray = { path = "../burn-ndarray", version = "0.17.0" } +burn-wgpu = { path = "../burn-wgpu", version = "0.17.0", default-features = false } +cubecl = { workspace = true } diff --git a/crates/burn-vision/src/backends/cpu/connected_components.rs b/crates/burn-vision/src/backends/cpu/connected_components.rs new file mode 100644 index 0000000000..3d78c08fbb --- /dev/null +++ b/crates/burn-vision/src/backends/cpu/connected_components.rs @@ -0,0 +1,233 @@ +use alloc::vec::Vec; +use burn_tensor::{ + backend::Backend, + ops::{BoolTensor, IntTensor}, + Bool, Int, Shape, Tensor, TensorData, +}; +use ndarray::{Array3, Axis}; + +use crate::{ConnectedStatsOptions, ConnectedStatsPrimitive, Connectivity}; + +mod spaghetti; +mod spaghetti_4c; + +pub fn connected_components( + img: BoolTensor, + connectivity: Connectivity, +) -> IntTensor { + run::(img, connectivity, || NoOp).0 +} + +pub fn connected_components_with_stats( + img: BoolTensor, + connectivity: Connectivity, + _options: ConnectedStatsOptions, +) -> (IntTensor, ConnectedStatsPrimitive) { + let device = B::bool_device(&img); + let (labels, stats) = run::(img, connectivity, ConnectedStatsOp::default); + let stats = finalize_stats(&device, stats); + (labels, stats) +} + +fn run( + img: BoolTensor, + connectivity: Connectivity, + stats: impl Fn() -> Stats, +) -> (IntTensor, Vec) { + let device = B::bool_device(&img); + let img = Tensor::::from_primitive(img); + let [batches, height, width] = img.shape().dims(); + let img = img.into_data().convert::().to_vec::().unwrap(); + let img = Array3::from_shape_vec((batches, height, width), img).unwrap(); + let mut stats_res = Vec::with_capacity(batches); + + let process = match connectivity { + Connectivity::Four => spaghetti_4c::process::, + Connectivity::Eight => spaghetti::process::, + }; + + let mut stats_0 = stats(); + let mut out = process(img.index_axis(Axis(0), 0), &mut stats_0); + stats_res.push(stats_0); + for i in 1..batches { + let mut stats_i = stats(); + let batch = process(img.index_axis(Axis(0), i), &mut stats_i); + out.append(Axis(0), batch.view()).unwrap(); + stats_res.push(stats_i); + } + let (data, _) = out.into_raw_vec_and_offset(); + let data = TensorData::new(data, Shape::new([batches, height, width])); + let labels = Tensor::::from_data(data, &device).into_primitive(); + (labels, stats_res) +} + +pub trait Solver { + fn init(max_labels: usize) -> Self; + /// Hack to get around mutable borrow limitations on methods + fn merge(label_1: u32, label_2: u32, solver: &mut Self) -> u32; + fn new_label(&mut self) -> u32; + fn flatten(&mut self) -> u32; + fn get_label(&mut self, i_label: u32) -> u32; +} + +pub(crate) struct UnionFind { + labels: Vec, +} + +impl Solver for UnionFind { + fn init(max_labels: usize) -> Self { + let mut labels = Vec::with_capacity(max_labels); + labels.push(0); + Self { labels } + } + + fn merge(mut label_1: u32, mut label_2: u32, solver: &mut Self) -> u32 { + while solver.labels[label_1 as usize] < label_1 { + label_1 = solver.labels[label_1 as usize]; + } + + while solver.labels[label_2 as usize] < label_2 { + label_2 = solver.labels[label_2 as usize]; + } + + if label_1 < label_2 { + solver.labels[label_2 as usize] = label_1; + label_1 + } else { + solver.labels[label_1 as usize] = label_2; + label_2 + } + } + + fn new_label(&mut self) -> u32 { + let len = self.labels.len() as u32; + self.labels.push(len); + len + } + + fn flatten(&mut self) -> u32 { + let mut k = 1; + for i in 1..self.labels.len() { + if self.labels[i] < i as u32 { + self.labels[i] = self.labels[self.labels[i] as usize]; + } else { + self.labels[i] = k; + k += 1; + } + } + k + } + + fn get_label(&mut self, i_label: u32) -> u32 { + self.labels[i_label as usize] + } +} + +pub trait StatsOp { + fn init(&mut self, num_labels: u32); + fn update(&mut self, row: usize, column: usize, label: u32); + fn finish(&mut self); +} + +struct NoOp; + +impl StatsOp for NoOp { + fn init(&mut self, _num_labels: u32) {} + + fn update(&mut self, _row: usize, _column: usize, _label: u32) {} + + fn finish(&mut self) {} +} + +#[derive(Default, Debug)] +struct ConnectedStatsOp { + pub area: Vec, + pub left: Vec, + pub top: Vec, + pub right: Vec, + pub bottom: Vec, +} + +impl StatsOp for ConnectedStatsOp { + fn init(&mut self, num_labels: u32) { + let num_labels = num_labels as usize; + self.area = vec![0; num_labels]; + self.left = vec![u32::MAX; num_labels]; + self.top = vec![u32::MAX; num_labels]; + self.right = vec![0; num_labels]; + self.bottom = vec![0; num_labels]; + } + + fn update(&mut self, row: usize, column: usize, label: u32) { + let l = label as usize; + self.area[l] += 1; + self.left[l] = self.left[l].min(column as u32); + self.top[l] = self.top[l].min(row as u32); + self.right[l] = self.right[l].max(column as u32); + self.bottom[l] = self.bottom[l].max(row as u32); + } + + fn finish(&mut self) { + // Background shouldn't have stats + self.area[0] = 0; + self.left[0] = 0; + self.right[0] = 0; + self.top[0] = 0; + self.bottom[0] = 0; + } +} + +fn finalize_stats( + device: &B::Device, + stats: Vec, +) -> ConnectedStatsPrimitive { + let batches = stats.len(); + let max_len = stats.iter().map(|it| it.area.len()).max().unwrap_or(1); + let mut area = Vec::with_capacity(batches * max_len); + let mut left = Vec::with_capacity(batches * max_len); + let mut top = Vec::with_capacity(batches * max_len); + let mut right = Vec::with_capacity(batches * max_len); + let mut bottom = Vec::with_capacity(batches * max_len); + let mut max_label = Vec::with_capacity(batches); + + for mut stats in stats { + max_label.push(stats.area.len() as u32 - 1); + stats.area.resize(max_len, 0); + stats.left.resize(max_len, 0); + stats.top.resize(max_len, 0); + stats.right.resize(max_len, 0); + stats.bottom.resize(max_len, 0); + + area.extend(stats.area); + left.extend(stats.left); + top.extend(stats.top); + right.extend(stats.right); + bottom.extend(stats.bottom); + } + + let into_prim = |data: Vec| { + let data = TensorData::new(data, Shape::new([batches, max_len])); + Tensor::::from_data(data, device).into_primitive() + }; + + let max_label = { + let data = TensorData::new(max_label, Shape::new([batches])); + Tensor::::from_data(data, device).into_primitive() + }; + + ConnectedStatsPrimitive { + area: into_prim(area), + left: into_prim(left), + top: into_prim(top), + right: into_prim(right), + bottom: into_prim(bottom), + max_label, + } +} + +pub fn max_labels(h: usize, w: usize, conn: Connectivity) -> usize { + match conn { + Connectivity::Four => ((h * w + 1) / 2) + 1, + Connectivity::Eight => ((h + 1) / 2) * ((w + 1) / 2) + 1, + } +} diff --git a/crates/burn-vision/src/backends/cpu/connected_components/spaghetti/Spaghetti_center_line_forest_code.rs b/crates/burn-vision/src/backends/cpu/connected_components/spaghetti/Spaghetti_center_line_forest_code.rs new file mode 100644 index 0000000000..60bb3c4ba7 --- /dev/null +++ b/crates/burn-vision/src/backends/cpu/connected_components/spaghetti/Spaghetti_center_line_forest_code.rs @@ -0,0 +1,1954 @@ +no_analyze! {{ +use centerLabels::*;let mut label = entry; +while let Some(next) = (|label| -> Option { match label { + NODE_1=> { + if img_row01[(c) as usize] > 0 { + if img_row00[(c + 1) as usize] > 0 { + return Some(NODE_2); + } + else { + img_labels_row00[c as usize] = solver.new_label(); + return Some(cl_tree_7); + } + } + else { + return Some(NODE_3); + } + } + NODE_3=> { + if img_row00[(c + 1) as usize] > 0 { + return Some(NODE_2); + } + else { + return Some(NODE_4); + } + } + NODE_4=> { + if img_row01[(c + 1) as usize] > 0 { + img_labels_row00[c as usize] = solver.new_label(); + return Some(cl_tree_2); + } + else { + img_labels_row00[c as usize] = 0; + return Some(cl_tree_1); + } + } + NODE_2=> { + if img_row11[(c + 1) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row12[(c) as usize]; + return Some(cl_tree_6); + } + else { + return Some(NODE_5); + } + } + NODE_5=> { + if img_row11[(c + 2) as usize] > 0 { + if img_row11[(c) as usize] > 0 { + return Some(NODE_6); + } + else { + img_labels_row00[c as usize] = img_labels_row12[(c + 2) as usize]; + return Some(cl_tree_5); + } + } + else { + if img_row11[(c) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row12[(c) as usize]; + return Some(cl_tree_4); + } + else { + img_labels_row00[c as usize] = solver.new_label(); + return Some(cl_tree_3); + } + } + } + NODE_7=> { + if img_row11[(c + 1) as usize] > 0 { + img_labels_row00[c as usize] = LabelsSolver::merge(img_labels_row00[(c - 2) as usize], img_labels_row12[(c) as usize], solver); + return Some(cl_tree_6); + } + else { + return Some(NODE_8); + } + } + NODE_9=> { + if img_row11[(c) as usize] > 0 { + img_labels_row00[c as usize] = LabelsSolver::merge(img_labels_row00[(c - 2) as usize], img_labels_row12[(c) as usize], solver); + return Some(cl_tree_4); + } + else { + img_labels_row00[c as usize] = img_labels_row00[(c - 2) as usize]; + return Some(cl_tree_3); + } + } + NODE_10=> { + if img_row11[(c + 1) as usize] > 0 { + if img_row11[(c) as usize] > 0 { + img_labels_row00[c as usize] = LabelsSolver::merge(img_labels_row00[(c - 2) as usize], img_labels_row12[(c) as usize], solver); + return Some(cl_tree_11); + } + else { + if img_row11[(c - 1) as usize] > 0 { + if img_row12[(c) as usize] > 0 { + img_labels_row00[c as usize] = LabelsSolver::merge(img_labels_row00[(c - 2) as usize], img_labels_row12[(c) as usize], solver); + return Some(cl_tree_11); + } + else { + img_labels_row00[c as usize] = LabelsSolver::merge(LabelsSolver::merge(img_labels_row00[(c - 2) as usize], img_labels_row12[(c) as usize], solver), img_labels_row12[(c - 2) as usize], solver); + return Some(cl_tree_11); + } + } + else { + img_labels_row00[c as usize] = LabelsSolver::merge(img_labels_row00[(c - 2) as usize], img_labels_row12[(c) as usize], solver); + return Some(cl_tree_11); + } + } + } + else { + if img_row00[(c + 1) as usize] > 0 { + if img_row11[(c + 2) as usize] > 0 { + if img_row11[(c) as usize] > 0 { + if img_row12[(c + 1) as usize] > 0 { + img_labels_row00[c as usize] = LabelsSolver::merge(img_labels_row00[(c - 2) as usize], img_labels_row12[(c + 2) as usize], solver); + return Some(cl_tree_5); + } + else { + img_labels_row00[c as usize] = LabelsSolver::merge(LabelsSolver::merge(img_labels_row00[(c - 2) as usize], img_labels_row12[(c + 2) as usize], solver), img_labels_row12[(c) as usize], solver); + return Some(cl_tree_5); + } + } + else { + if img_row11[(c - 1) as usize] > 0 { + if img_row12[(c + 1) as usize] > 0 { + if img_row12[(c) as usize] > 0 { + img_labels_row00[c as usize] = LabelsSolver::merge(img_labels_row00[(c - 2) as usize], img_labels_row12[(c + 2) as usize], solver); + return Some(cl_tree_5); + } + else { + img_labels_row00[c as usize] = LabelsSolver::merge(LabelsSolver::merge(img_labels_row00[(c - 2) as usize], img_labels_row12[(c + 2) as usize], solver), img_labels_row12[(c - 2) as usize], solver); + return Some(cl_tree_5); + } + } + else { + img_labels_row00[c as usize] = LabelsSolver::merge(LabelsSolver::merge(img_labels_row00[(c - 2) as usize], img_labels_row12[(c + 2) as usize], solver), img_labels_row12[(c - 2) as usize], solver); + return Some(cl_tree_5); + } + } + else { + img_labels_row00[c as usize] = LabelsSolver::merge(img_labels_row00[(c - 2) as usize], img_labels_row12[(c + 2) as usize], solver); + return Some(cl_tree_5); + } + } + } + else { + if img_row11[(c - 1) as usize] > 0 { + img_labels_row00[c as usize] = LabelsSolver::merge(img_labels_row00[(c - 2) as usize], img_labels_row12[(c - 2) as usize], solver); + return Some(cl_tree_8); + } + else { + return Some(NODE_9); + } + } + } + else { + if img_row11[(c - 1) as usize] > 0 { + img_labels_row00[c as usize] = LabelsSolver::merge(img_labels_row00[(c - 2) as usize], img_labels_row12[(c - 2) as usize], solver); + return Some(cl_tree_12); + } + else { + return Some(NODE_11); + } + } + } + } + NODE_8=> { + if img_row11[(c + 2) as usize] > 0 { + if img_row12[(c + 1) as usize] > 0 { + img_labels_row00[c as usize] = LabelsSolver::merge(img_labels_row00[(c - 2) as usize], img_labels_row12[(c + 2) as usize], solver); + return Some(cl_tree_5); + } + else { + if img_row11[(c) as usize] > 0 { + img_labels_row00[c as usize] = LabelsSolver::merge(LabelsSolver::merge(img_labels_row00[(c - 2) as usize], img_labels_row12[(c + 2) as usize], solver), img_labels_row12[(c) as usize], solver); + return Some(cl_tree_5); + } + else { + img_labels_row00[c as usize] = LabelsSolver::merge(img_labels_row00[(c - 2) as usize], img_labels_row12[(c + 2) as usize], solver); + return Some(cl_tree_5); + } + } + } + else { + return Some(NODE_9); + } + } + NODE_12=> { + if img_row11[(c + 1) as usize] > 0 { + if img_row11[(c) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row12[(c) as usize]; + return Some(cl_tree_11); + } + else { + if img_row11[(c - 1) as usize] > 0 { + return Some(NODE_13); + } + else { + img_labels_row00[c as usize] = img_labels_row12[(c) as usize]; + return Some(cl_tree_11); + } + } + } + else { + if img_row00[(c + 1) as usize] > 0 { + if img_row11[(c + 2) as usize] > 0 { + if img_row11[(c) as usize] > 0 { + return Some(NODE_6); + } + else { + if img_row11[(c - 1) as usize] > 0 { + return Some(NODE_14); + } + else { + img_labels_row00[c as usize] = img_labels_row12[(c + 2) as usize]; + return Some(cl_tree_5); + } + } + } + else { + if img_row11[(c) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row12[(c) as usize]; + return Some(cl_tree_4); + } + else { + if img_row11[(c - 1) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row12[(c - 2) as usize]; + return Some(cl_tree_3); + } + else { + img_labels_row00[c as usize] = solver.new_label(); + return Some(cl_tree_3); + } + } + } + } + else { + if img_row11[(c) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row12[(c) as usize]; + return Some(cl_tree_10); + } + else { + if img_row11[(c - 1) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row12[(c - 2) as usize]; + return Some(cl_tree_9); + } + else { + img_labels_row00[c as usize] = solver.new_label(); + return Some(cl_tree_9); + } + } + } + } + } + NODE_6=> { + if img_row12[(c + 1) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row12[(c + 2) as usize]; + return Some(cl_tree_5); + } + else { + img_labels_row00[c as usize] = LabelsSolver::merge(img_labels_row12[(c + 2) as usize], img_labels_row12[(c) as usize], solver); + return Some(cl_tree_5); + } + } + NODE_15=> { + if img_row11[(c + 1) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row12[(c) as usize]; + return Some(cl_tree_11); + } + else { + if img_row00[(c + 1) as usize] > 0 { + return Some(NODE_5); + } + else { + if img_row11[(c) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row12[(c) as usize]; + return Some(cl_tree_10); + } + else { + img_labels_row00[c as usize] = solver.new_label(); + return Some(cl_tree_9); + } + } + } + } + NODE_11=> { + if img_row11[(c) as usize] > 0 { + img_labels_row00[c as usize] = LabelsSolver::merge(img_labels_row00[(c - 2) as usize], img_labels_row12[(c) as usize], solver); + return Some(cl_tree_10); + } + else { + img_labels_row00[c as usize] = img_labels_row00[(c - 2) as usize]; + return Some(cl_tree_9); + } + } + NODE_13=> { + if img_row12[(c) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row12[(c) as usize]; + return Some(cl_tree_11); + } + else { + img_labels_row00[c as usize] = LabelsSolver::merge(img_labels_row12[(c) as usize], img_labels_row12[(c - 2) as usize], solver); + return Some(cl_tree_11); + } + } + NODE_14=> { + if img_row12[(c + 1) as usize] > 0 { + if img_row12[(c) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row12[(c + 2) as usize]; + return Some(cl_tree_5); + } + else { + img_labels_row00[c as usize] = LabelsSolver::merge(img_labels_row12[(c + 2) as usize], img_labels_row12[(c - 2) as usize], solver); + return Some(cl_tree_5); + } + } + else { + img_labels_row00[c as usize] = LabelsSolver::merge(img_labels_row12[(c + 2) as usize], img_labels_row12[(c - 2) as usize], solver); + return Some(cl_tree_5); + } + } + NODE_16=> { + if img_row11[(c + 1) as usize] > 0 { + return Some(NODE_17); + } + else { + return Some(NODE_18); + } + } + NODE_18=> { + if img_row11[(c + 2) as usize] > 0 { + return Some(NODE_19); + } + else { + img_labels_row00[c as usize] = img_labels_row00[(c - 2) as usize]; + return Some(cl_tree_8); + } + } + NODE_17=> { + if img_row11[(c) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row12[(c) as usize]; + return Some(cl_tree_6); + } + else { + if img_row12[(c) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row12[(c) as usize]; + return Some(cl_tree_6); + } + else { + img_labels_row00[c as usize] = LabelsSolver::merge(img_labels_row00[(c - 2) as usize], img_labels_row12[(c) as usize], solver); + return Some(cl_tree_6); + } + } + } + NODE_20=> { + if img_row11[(c) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row12[(c + 2) as usize]; + return Some(cl_tree_5); + } + else { + if img_row12[(c) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row12[(c + 2) as usize]; + return Some(cl_tree_5); + } + else { + img_labels_row00[c as usize] = LabelsSolver::merge(img_labels_row00[(c - 2) as usize], img_labels_row12[(c + 2) as usize], solver); + return Some(cl_tree_5); + } + } + } + NODE_21=> { + if img_row11[(c + 1) as usize] > 0 { + return Some(NODE_22); + } + else { + if img_row00[(c + 1) as usize] > 0 { + return Some(NODE_18); + } + else { + img_labels_row00[c as usize] = img_labels_row00[(c - 2) as usize]; + return Some(cl_tree_12); + } + } + } + NODE_23=> { + if img_row11[(c + 2) as usize] > 0 { + img_labels_row00[c as usize] = LabelsSolver::merge(img_labels_row00[(c - 2) as usize], img_labels_row12[(c + 2) as usize], solver); + return Some(cl_tree_5); + } + else { + img_labels_row00[c as usize] = img_labels_row00[(c - 2) as usize]; + return Some(cl_tree_8); + } + } + NODE_24=> { + if img_row12[(c - 1) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row12[(c) as usize]; + return Some(cl_tree_6); + } + else { + img_labels_row00[c as usize] = LabelsSolver::merge(img_labels_row00[(c - 2) as usize], img_labels_row12[(c) as usize], solver); + return Some(cl_tree_6); + } + } + NODE_25=> { + if img_row01[(c - 1) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row00[(c - 2) as usize]; + return Some(cl_tree_7); + } + else { + img_labels_row00[c as usize] = solver.new_label(); + return Some(cl_tree_7); + } + } + NODE_26=> { + if img_row11[(c + 2) as usize] > 0 { + if img_row12[(c + 1) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row12[(c + 2) as usize]; + return Some(cl_tree_5); + } + else { + img_labels_row00[c as usize] = LabelsSolver::merge(img_labels_row00[(c - 2) as usize], img_labels_row12[(c + 2) as usize], solver); + return Some(cl_tree_5); + } + } + else { + img_labels_row00[c as usize] = img_labels_row00[(c - 2) as usize]; + return Some(cl_tree_8); + } + } + NODE_19=> { + if img_row12[(c + 1) as usize] > 0 { + return Some(NODE_20); + } + else { + img_labels_row00[c as usize] = LabelsSolver::merge(img_labels_row00[(c - 2) as usize], img_labels_row12[(c + 2) as usize], solver); + return Some(cl_tree_5); + } + } + NODE_27=> { + if img_row01[(c) as usize] > 0 { + if img_row00[(c + 1) as usize] > 0 { + if img_row01[(c - 1) as usize] > 0 { + return Some(NODE_7); + } + else { + return Some(NODE_2); + } + } + else { + return Some(NODE_25); + } + } + else { + return Some(NODE_3); + } + } + NODE_28=> { + if img_row12[(c - 1) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row12[(c) as usize]; + return Some(cl_tree_11); + } + else { + img_labels_row00[c as usize] = LabelsSolver::merge(img_labels_row00[(c - 2) as usize], img_labels_row12[(c) as usize], solver); + return Some(cl_tree_11); + } + } + NODE_29=> { + if img_row00[(c + 1) as usize] > 0 { + if img_row11[(c + 1) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row12[(c) as usize]; + return Some(cl_tree_6); + } + else { + if img_row11[(c + 2) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row12[(c + 2) as usize]; + return Some(cl_tree_5); + } + else { + img_labels_row00[c as usize] = solver.new_label(); + return Some(cl_tree_3); + } + } + } + else { + return Some(NODE_4); + } + } + NODE_30=> { + if img_row11[(c + 2) as usize] > 0 { + if img_row12[(c + 1) as usize] > 0 { + if img_row12[(c) as usize] > 0 { + return Some(NODE_31); + } + else { + img_labels_row00[c as usize] = LabelsSolver::merge(img_labels_row00[(c - 2) as usize], img_labels_row12[(c + 2) as usize], solver); + return Some(cl_tree_5); + } + } + else { + img_labels_row00[c as usize] = LabelsSolver::merge(img_labels_row00[(c - 2) as usize], img_labels_row12[(c + 2) as usize], solver); + return Some(cl_tree_5); + } + } + else { + img_labels_row00[c as usize] = img_labels_row00[(c - 2) as usize]; + return Some(cl_tree_8); + } + } + NODE_22=> { + if img_row11[(c) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row12[(c) as usize]; + return Some(cl_tree_11); + } + else { + if img_row12[(c) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row12[(c) as usize]; + return Some(cl_tree_11); + } + else { + img_labels_row00[c as usize] = LabelsSolver::merge(img_labels_row00[(c - 2) as usize], img_labels_row12[(c) as usize], solver); + return Some(cl_tree_11); + } + } + } + NODE_31=> { + if img_row12[(c - 1) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row12[(c + 2) as usize]; + return Some(cl_tree_5); + } + else { + img_labels_row00[c as usize] = LabelsSolver::merge(img_labels_row00[(c - 2) as usize], img_labels_row12[(c + 2) as usize], solver); + return Some(cl_tree_5); + } + } + NODE_32=> { + if img_row12[(c - 1) as usize] > 0 { + return Some(NODE_17); + } + else { + img_labels_row00[c as usize] = LabelsSolver::merge(img_labels_row00[(c - 2) as usize], img_labels_row12[(c) as usize], solver); + return Some(cl_tree_6); + } + } + NODE_33=> { + if img_row12[(c - 1) as usize] > 0 { + return Some(NODE_20); + } + else { + img_labels_row00[c as usize] = LabelsSolver::merge(img_labels_row00[(c - 2) as usize], img_labels_row12[(c + 2) as usize], solver); + return Some(cl_tree_5); + } + } + NODE_34=> { + if img_row12[(c - 1) as usize] > 0 { + return Some(NODE_22); + } + else { + img_labels_row00[c as usize] = LabelsSolver::merge(img_labels_row00[(c - 2) as usize], img_labels_row12[(c) as usize], solver); + return Some(cl_tree_11); + } + } + NODE_35=> { + if img_row11[(c + 2) as usize] > 0 { + if img_row12[(c + 1) as usize] > 0 { + if img_row11[(c - 2) as usize] > 0 { + return Some(NODE_33); + } + else { + img_labels_row00[c as usize] = LabelsSolver::merge(img_labels_row00[(c - 2) as usize], img_labels_row12[(c + 2) as usize], solver); + return Some(cl_tree_5); + } + } + else { + if img_row11[(c) as usize] > 0 { + if img_row11[(c - 2) as usize] > 0 { + return Some(NODE_36); + } + else { + img_labels_row00[c as usize] = LabelsSolver::merge(LabelsSolver::merge(img_labels_row00[(c - 2) as usize], img_labels_row12[(c + 2) as usize], solver), img_labels_row12[(c) as usize], solver); + return Some(cl_tree_5); + } + } + else { + img_labels_row00[c as usize] = LabelsSolver::merge(img_labels_row00[(c - 2) as usize], img_labels_row12[(c + 2) as usize], solver); + return Some(cl_tree_5); + } + } + } + else { + if img_row11[(c) as usize] > 0 { + if img_row11[(c - 2) as usize] > 0 { + return Some(NODE_37); + } + else { + img_labels_row00[c as usize] = LabelsSolver::merge(img_labels_row00[(c - 2) as usize], img_labels_row12[(c) as usize], solver); + return Some(cl_tree_4); + } + } + else { + img_labels_row00[c as usize] = img_labels_row00[(c - 2) as usize]; + return Some(cl_tree_3); + } + } + } + NODE_37=> { + if img_row12[(c - 1) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row12[(c) as usize]; + return Some(cl_tree_4); + } + else { + img_labels_row00[c as usize] = LabelsSolver::merge(img_labels_row00[(c - 2) as usize], img_labels_row12[(c) as usize], solver); + return Some(cl_tree_4); + } + } + NODE_38=> { + if img_row11[(c + 2) as usize] > 0 { + if img_row12[(c + 1) as usize] > 0 { + if img_row12[(c) as usize] > 0 { + if img_row11[(c - 2) as usize] > 0 { + return Some(NODE_31); + } + else { + img_labels_row00[c as usize] = LabelsSolver::merge(img_labels_row00[(c - 2) as usize], img_labels_row12[(c + 2) as usize], solver); + return Some(cl_tree_5); + } + } + else { + img_labels_row00[c as usize] = LabelsSolver::merge(img_labels_row00[(c - 2) as usize], img_labels_row12[(c + 2) as usize], solver); + return Some(cl_tree_5); + } + } + else { + img_labels_row00[c as usize] = LabelsSolver::merge(img_labels_row00[(c - 2) as usize], img_labels_row12[(c + 2) as usize], solver); + return Some(cl_tree_5); + } + } + else { + img_labels_row00[c as usize] = img_labels_row00[(c - 2) as usize]; + return Some(cl_tree_8); + } + } + NODE_39=> { + if img_row11[(c + 2) as usize] > 0 { + if img_row12[(c + 1) as usize] > 0 { + return Some(NODE_33); + } + else { + if img_row11[(c) as usize] > 0 { + return Some(NODE_36); + } + else { + img_labels_row00[c as usize] = LabelsSolver::merge(img_labels_row00[(c - 2) as usize], img_labels_row12[(c + 2) as usize], solver); + return Some(cl_tree_5); + } + } + } + else { + if img_row11[(c) as usize] > 0 { + return Some(NODE_37); + } + else { + img_labels_row00[c as usize] = img_labels_row00[(c - 2) as usize]; + return Some(cl_tree_3); + } + } + } + NODE_36=> { + if img_row12[(c - 1) as usize] > 0 { + img_labels_row00[c as usize] = LabelsSolver::merge(img_labels_row12[(c + 2) as usize], img_labels_row12[(c) as usize], solver); + return Some(cl_tree_5); + } + else { + img_labels_row00[c as usize] = LabelsSolver::merge(LabelsSolver::merge(img_labels_row00[(c - 2) as usize], img_labels_row12[(c + 2) as usize], solver), img_labels_row12[(c) as usize], solver); + return Some(cl_tree_5); + } + } + NODE_40=> { + if img_row12[(c - 1) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row12[(c) as usize]; + return Some(cl_tree_10); + } + else { + img_labels_row00[c as usize] = LabelsSolver::merge(img_labels_row00[(c - 2) as usize], img_labels_row12[(c) as usize], solver); + return Some(cl_tree_10); + } + } +cl_tree_0 => { +if ({c+=2; c}) >= w - 2 { if c > w - 2 { return Some(cl_break_0_0); } else { return Some(cl_break_1_0); } } + if img_row00[(c) as usize] > 0 { + return Some(NODE_15); + } + else { + return Some(NODE_1); + } +} +cl_tree_1 => { +if ({c+=2; c}) >= w - 2 { if c > w - 2 { return Some(cl_break_0_1); } else { return Some(cl_break_1_1); } } + if img_row00[(c) as usize] > 0 { + return Some(NODE_12); + } + else { + return Some(NODE_1); + } +} +cl_tree_2 => { +if ({c+=2; c}) >= w - 2 { if c > w - 2 { return Some(cl_break_0_2); } else { return Some(cl_break_1_2); } } + if img_row00[(c) as usize] > 0 { + return Some(NODE_10); + } + else { + if img_row01[(c) as usize] > 0 { + if img_row00[(c + 1) as usize] > 0 { + return Some(NODE_7); + } + else { + img_labels_row00[c as usize] = img_labels_row00[(c - 2) as usize]; + return Some(cl_tree_7); + } + } + else { + return Some(NODE_3); + } + } +} +cl_tree_3 => { +if ({c+=2; c}) >= w - 2 { if c > w - 2 { return Some(cl_break_0_3); } else { return Some(cl_break_1_3); } } + if img_row00[(c) as usize] > 0 { + if img_row11[(c + 1) as usize] > 0 { + img_labels_row00[c as usize] = LabelsSolver::merge(img_labels_row00[(c - 2) as usize], img_labels_row12[(c) as usize], solver); + return Some(cl_tree_11); + } + else { + if img_row00[(c + 1) as usize] > 0 { + return Some(NODE_23); + } + else { + img_labels_row00[c as usize] = img_labels_row00[(c - 2) as usize]; + return Some(cl_tree_12); + } + } + } + else { + if img_row01[(c) as usize] > 0 { + if img_row00[(c + 1) as usize] > 0 { + if img_row11[(c + 1) as usize] > 0 { + img_labels_row00[c as usize] = LabelsSolver::merge(img_labels_row00[(c - 2) as usize], img_labels_row12[(c) as usize], solver); + return Some(cl_tree_6); + } + else { + return Some(NODE_23); + } + } + else { + img_labels_row00[c as usize] = img_labels_row00[(c - 2) as usize]; + return Some(cl_tree_7); + } + } + else { + return Some(NODE_29); + } + } +} +cl_tree_4 => { +if ({c+=2; c}) >= w - 2 { if c > w - 2 { return Some(cl_break_0_3); } else { return Some(cl_break_1_4); } } + if img_row00[(c) as usize] > 0 { + if img_row11[(c + 1) as usize] > 0 { + if img_row12[(c) as usize] > 0 { + return Some(NODE_28); + } + else { + img_labels_row00[c as usize] = LabelsSolver::merge(img_labels_row00[(c - 2) as usize], img_labels_row12[(c) as usize], solver); + return Some(cl_tree_11); + } + } + else { + if img_row00[(c + 1) as usize] > 0 { + return Some(NODE_30); + } + else { + img_labels_row00[c as usize] = img_labels_row00[(c - 2) as usize]; + return Some(cl_tree_12); + } + } + } + else { + if img_row01[(c) as usize] > 0 { + if img_row00[(c + 1) as usize] > 0 { + if img_row11[(c + 1) as usize] > 0 { + if img_row12[(c) as usize] > 0 { + return Some(NODE_24); + } + else { + img_labels_row00[c as usize] = LabelsSolver::merge(img_labels_row00[(c - 2) as usize], img_labels_row12[(c) as usize], solver); + return Some(cl_tree_6); + } + } + else { + return Some(NODE_30); + } + } + else { + img_labels_row00[c as usize] = img_labels_row00[(c - 2) as usize]; + return Some(cl_tree_7); + } + } + else { + return Some(NODE_29); + } + } +} +cl_tree_5 => { +if ({c+=2; c}) >= w - 2 { if c > w - 2 { return Some(cl_break_0_3); } else { return Some(cl_break_1_5); } } + if img_row00[(c) as usize] > 0 { + if img_row11[(c + 1) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row12[(c) as usize]; + return Some(cl_tree_11); + } + else { + if img_row00[(c + 1) as usize] > 0 { + return Some(NODE_26); + } + else { + img_labels_row00[c as usize] = img_labels_row00[(c - 2) as usize]; + return Some(cl_tree_12); + } + } + } + else { + if img_row01[(c) as usize] > 0 { + if img_row00[(c + 1) as usize] > 0 { + if img_row11[(c + 1) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row12[(c) as usize]; + return Some(cl_tree_6); + } + else { + return Some(NODE_26); + } + } + else { + img_labels_row00[c as usize] = img_labels_row00[(c - 2) as usize]; + return Some(cl_tree_7); + } + } + else { + if img_row00[(c + 1) as usize] > 0 { + if img_row11[(c + 1) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row12[(c) as usize]; + return Some(cl_tree_6); + } + else { + if img_row11[(c + 2) as usize] > 0 { + return Some(NODE_6); + } + else { + img_labels_row00[c as usize] = img_labels_row12[(c) as usize]; + return Some(cl_tree_4); + } + } + } + else { + return Some(NODE_4); + } + } + } +} +cl_tree_6 => { +if ({c+=2; c}) >= w - 2 { if c > w - 2 { return Some(cl_break_0_3); } else { return Some(cl_break_1_6); } } + if img_row00[(c) as usize] > 0 { + return Some(NODE_21); + } + else { + if img_row01[(c) as usize] > 0 { + if img_row00[(c + 1) as usize] > 0 { + return Some(NODE_16); + } + else { + img_labels_row00[c as usize] = img_labels_row00[(c - 2) as usize]; + return Some(cl_tree_7); + } + } + else { + return Some(NODE_3); + } + } +} +cl_tree_7 => { +if ({c+=2; c}) >= w - 2 { if c > w - 2 { return Some(cl_break_0_4); } else { return Some(cl_break_1_7); } } + if img_row00[(c) as usize] > 0 { + if img_row01[(c - 1) as usize] > 0 { + return Some(NODE_10); + } + else { + return Some(NODE_12); + } + } + else { + return Some(NODE_27); + } +} +cl_tree_8 => { +if ({c+=2; c}) >= w - 2 { if c > w - 2 { return Some(cl_break_0_3); } else { return Some(cl_break_1_8); } } + if img_row00[(c) as usize] > 0 { + if img_row11[(c + 1) as usize] > 0 { + if img_row12[(c) as usize] > 0 { + if img_row11[(c - 2) as usize] > 0 { + return Some(NODE_28); + } + else { + img_labels_row00[c as usize] = LabelsSolver::merge(img_labels_row00[(c - 2) as usize], img_labels_row12[(c) as usize], solver); + return Some(cl_tree_11); + } + } + else { + img_labels_row00[c as usize] = LabelsSolver::merge(img_labels_row00[(c - 2) as usize], img_labels_row12[(c) as usize], solver); + return Some(cl_tree_11); + } + } + else { + if img_row00[(c + 1) as usize] > 0 { + return Some(NODE_38); + } + else { + img_labels_row00[c as usize] = img_labels_row00[(c - 2) as usize]; + return Some(cl_tree_12); + } + } + } + else { + if img_row01[(c) as usize] > 0 { + if img_row00[(c + 1) as usize] > 0 { + if img_row11[(c + 1) as usize] > 0 { + if img_row12[(c) as usize] > 0 { + if img_row11[(c - 2) as usize] > 0 { + return Some(NODE_24); + } + else { + img_labels_row00[c as usize] = LabelsSolver::merge(img_labels_row00[(c - 2) as usize], img_labels_row12[(c) as usize], solver); + return Some(cl_tree_6); + } + } + else { + img_labels_row00[c as usize] = LabelsSolver::merge(img_labels_row00[(c - 2) as usize], img_labels_row12[(c) as usize], solver); + return Some(cl_tree_6); + } + } + else { + return Some(NODE_38); + } + } + else { + img_labels_row00[c as usize] = img_labels_row00[(c - 2) as usize]; + return Some(cl_tree_7); + } + } + else { + return Some(NODE_29); + } + } +} +cl_tree_9 => { +if ({c+=2; c}) >= w - 2 { if c > w - 2 { return Some(cl_break_0_5); } else { return Some(cl_break_1_9); } } + if img_row00[(c) as usize] > 0 { + if img_row01[(c - 1) as usize] > 0 { + if img_row11[(c + 1) as usize] > 0 { + img_labels_row00[c as usize] = LabelsSolver::merge(img_labels_row00[(c - 2) as usize], img_labels_row12[(c) as usize], solver); + return Some(cl_tree_11); + } + else { + if img_row00[(c + 1) as usize] > 0 { + return Some(NODE_8); + } + else { + return Some(NODE_11); + } + } + } + else { + return Some(NODE_15); + } + } + else { + return Some(NODE_27); + } +} +cl_tree_10 => { +if ({c+=2; c}) >= w - 2 { if c > w - 2 { return Some(cl_break_0_6); } else { return Some(cl_break_1_10); } } + if img_row00[(c) as usize] > 0 { + if img_row01[(c - 1) as usize] > 0 { + if img_row11[(c + 1) as usize] > 0 { + return Some(NODE_34); + } + else { + if img_row00[(c + 1) as usize] > 0 { + return Some(NODE_39); + } + else { + if img_row11[(c) as usize] > 0 { + return Some(NODE_40); + } + else { + img_labels_row00[c as usize] = img_labels_row00[(c - 2) as usize]; + return Some(cl_tree_9); + } + } + } + } + else { + return Some(NODE_15); + } + } + else { + if img_row01[(c) as usize] > 0 { + if img_row00[(c + 1) as usize] > 0 { + if img_row01[(c - 1) as usize] > 0 { + if img_row11[(c + 1) as usize] > 0 { + return Some(NODE_32); + } + else { + return Some(NODE_39); + } + } + else { + return Some(NODE_2); + } + } + else { + return Some(NODE_25); + } + } + else { + return Some(NODE_3); + } + } +} +cl_tree_11 => { +if ({c+=2; c}) >= w - 2 { if c > w - 2 { return Some(cl_break_0_7); } else { return Some(cl_break_1_11); } } + if img_row00[(c) as usize] > 0 { + if img_row00[(c - 1) as usize] > 0 { + return Some(NODE_21); + } + else { + if img_row01[(c - 1) as usize] > 0 { + return Some(NODE_21); + } + else { + if img_row11[(c + 1) as usize] > 0 { + if img_row11[(c) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row12[(c) as usize]; + return Some(cl_tree_11); + } + else { + return Some(NODE_13); + } + } + else { + if img_row00[(c + 1) as usize] > 0 { + if img_row11[(c + 2) as usize] > 0 { + if img_row11[(c) as usize] > 0 { + return Some(NODE_6); + } + else { + return Some(NODE_14); + } + } + else { + if img_row11[(c) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row12[(c) as usize]; + return Some(cl_tree_4); + } + else { + img_labels_row00[c as usize] = img_labels_row12[(c - 2) as usize]; + return Some(cl_tree_3); + } + } + } + else { + if img_row11[(c) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row12[(c) as usize]; + return Some(cl_tree_10); + } + else { + img_labels_row00[c as usize] = img_labels_row12[(c - 2) as usize]; + return Some(cl_tree_9); + } + } + } + } + } + } + else { + if img_row01[(c) as usize] > 0 { + if img_row00[(c + 1) as usize] > 0 { + if img_row00[(c - 1) as usize] > 0 { + return Some(NODE_16); + } + else { + if img_row01[(c - 1) as usize] > 0 { + if img_row11[(c + 1) as usize] > 0 { + return Some(NODE_17); + } + else { + if img_row11[(c + 2) as usize] > 0 { + return Some(NODE_19); + } + else { + if img_row11[(c) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row12[(c) as usize]; + return Some(cl_tree_4); + } + else { + img_labels_row00[c as usize] = img_labels_row00[(c - 2) as usize]; + return Some(cl_tree_3); + } + } + } + } + else { + return Some(NODE_2); + } + } + } + else { + if img_row01[(c - 1) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row00[(c - 2) as usize]; + return Some(cl_tree_7); + } + else { + if img_row00[(c - 1) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row00[(c - 2) as usize]; + return Some(cl_tree_7); + } + else { + img_labels_row00[c as usize] = solver.new_label(); + return Some(cl_tree_7); + } + } + } + } + else { + return Some(NODE_3); + } + } +} +cl_tree_12 => { +if ({c+=2; c}) >= w - 2 { if c > w - 2 { return Some(cl_break_0_8); } else { return Some(cl_break_1_12); } } + if img_row00[(c) as usize] > 0 { + if img_row01[(c - 1) as usize] > 0 { + if img_row11[(c + 1) as usize] > 0 { + if img_row11[(c - 2) as usize] > 0 { + return Some(NODE_34); + } + else { + img_labels_row00[c as usize] = LabelsSolver::merge(img_labels_row00[(c - 2) as usize], img_labels_row12[(c) as usize], solver); + return Some(cl_tree_11); + } + } + else { + if img_row00[(c + 1) as usize] > 0 { + return Some(NODE_35); + } + else { + if img_row11[(c) as usize] > 0 { + if img_row11[(c - 2) as usize] > 0 { + return Some(NODE_40); + } + else { + img_labels_row00[c as usize] = LabelsSolver::merge(img_labels_row00[(c - 2) as usize], img_labels_row12[(c) as usize], solver); + return Some(cl_tree_10); + } + } + else { + img_labels_row00[c as usize] = img_labels_row00[(c - 2) as usize]; + return Some(cl_tree_9); + } + } + } + } + else { + return Some(NODE_15); + } + } + else { + if img_row01[(c) as usize] > 0 { + if img_row00[(c + 1) as usize] > 0 { + if img_row01[(c - 1) as usize] > 0 { + if img_row11[(c + 1) as usize] > 0 { + if img_row11[(c - 2) as usize] > 0 { + return Some(NODE_32); + } + else { + img_labels_row00[c as usize] = LabelsSolver::merge(img_labels_row00[(c - 2) as usize], img_labels_row12[(c) as usize], solver); + return Some(cl_tree_6); + } + } + else { + return Some(NODE_35); + } + } + else { + return Some(NODE_2); + } + } + else { + return Some(NODE_25); + } + } + else { + return Some(NODE_3); + } + } +} + NODE_41=> { + if img_row11[(c) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row12[(c) as usize]; + } + else { + if img_row11[(c - 1) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row12[(c - 2) as usize]; + } + else { + img_labels_row00[c as usize] = solver.new_label(); + } + } + } + NODE_42=> { + if img_row12[(c - 1) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row12[(c - 2) as usize]; + } + else { + img_labels_row00[c as usize] = LabelsSolver::merge(img_labels_row00[(c - 2) as usize], img_labels_row12[(c) as usize], solver); + } + } + NODE_43=> { + if img_row01[(c) as usize] > 0 { + img_labels_row00[c as usize] = solver.new_label(); + } + else { + img_labels_row00[c as usize] = 0; + } + } + NODE_44=> { + if img_row01[(c) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row00[(c - 2) as usize]; + } + else { + img_labels_row00[c as usize] = 0; + } + } + NODE_45=> { + if img_row11[(c - 1) as usize] > 0 { + img_labels_row00[c as usize] = LabelsSolver::merge(img_labels_row00[(c - 2) as usize], img_labels_row12[(c - 2) as usize], solver); + } + else { + return Some(NODE_46); + } + } + NODE_47=> { + if img_row11[(c) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row12[(c) as usize]; + } + else { + img_labels_row00[c as usize] = solver.new_label(); + } + } + NODE_48=> { + if img_row01[(c) as usize] > 0 { + if img_row01[(c - 1) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row00[(c - 2) as usize]; + } + else { + img_labels_row00[c as usize] = solver.new_label(); + } + } + else { + img_labels_row00[c as usize] = 0; + } + } + NODE_46=> { + if img_row11[(c) as usize] > 0 { + img_labels_row00[c as usize] = LabelsSolver::merge(img_labels_row00[(c - 2) as usize], img_labels_row12[(c) as usize], solver); + } + else { + img_labels_row00[c as usize] = img_labels_row00[(c - 2) as usize]; + } + } +cl_break_0_0 => { + if img_row00[(c) as usize] > 0 { + return Some(NODE_47); + } + else { + return Some(NODE_43); + } + return None;} +cl_break_0_1 => { + if img_row00[(c) as usize] > 0 { + return Some(NODE_41); + } + else { + return Some(NODE_43); + } + return None;} +cl_break_0_2 => { + if img_row00[(c) as usize] > 0 { + return Some(NODE_45); + } + else { + return Some(NODE_44); + } + return None;} +cl_break_0_3 => { + if img_row00[(c) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row00[(c - 2) as usize]; + } + else { + return Some(NODE_44); + } + return None;} +cl_break_0_4 => { + if img_row00[(c) as usize] > 0 { + if img_row01[(c - 1) as usize] > 0 { + return Some(NODE_45); + } + else { + return Some(NODE_41); + } + } + else { + return Some(NODE_48); + } + return None;} +cl_break_0_5 => { + if img_row00[(c) as usize] > 0 { + if img_row01[(c - 1) as usize] > 0 { + return Some(NODE_46); + } + else { + return Some(NODE_47); + } + } + else { + return Some(NODE_48); + } + return None;} +cl_break_0_6 => { + if img_row00[(c) as usize] > 0 { + if img_row01[(c - 1) as usize] > 0 { + if img_row11[(c) as usize] > 0 { + return Some(NODE_42); + } + else { + img_labels_row00[c as usize] = img_labels_row00[(c - 2) as usize]; + } + } + else { + return Some(NODE_47); + } + } + else { + return Some(NODE_48); + } + return None;} +cl_break_0_7 => { + if img_row00[(c) as usize] > 0 { + if img_row00[(c - 1) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row00[(c - 2) as usize]; + } + else { + if img_row01[(c - 1) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row12[(c - 2) as usize]; + } + else { + if img_row11[(c) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row12[(c) as usize]; + } + else { + img_labels_row00[c as usize] = img_labels_row12[(c - 2) as usize]; + } + } + } + } + else { + if img_row01[(c) as usize] > 0 { + if img_row01[(c - 1) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row00[(c - 2) as usize]; + } + else { + if img_row00[(c - 1) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row00[(c - 2) as usize]; + } + else { + img_labels_row00[c as usize] = solver.new_label(); + } + } + } + else { + img_labels_row00[c as usize] = 0; + } + } + return None;} +cl_break_0_8 => { + if img_row00[(c) as usize] > 0 { + if img_row01[(c - 1) as usize] > 0 { + if img_row11[(c) as usize] > 0 { + if img_row11[(c - 2) as usize] > 0 { + return Some(NODE_42); + } + else { + img_labels_row00[c as usize] = LabelsSolver::merge(img_labels_row00[(c - 2) as usize], img_labels_row12[(c) as usize], solver); + } + } + else { + img_labels_row00[c as usize] = img_labels_row00[(c - 2) as usize]; + } + } + else { + return Some(NODE_47); + } + } + else { + return Some(NODE_48); + } + return None;} + NODE_49=> { + if img_row11[(c + 1) as usize] > 0 { + if img_row12[(c) as usize] > 0 { + return Some(NODE_50); + } + else { + img_labels_row00[c as usize] = LabelsSolver::merge(img_labels_row00[(c - 2) as usize], img_labels_row12[(c) as usize], solver); + } + } + else { + img_labels_row00[c as usize] = img_labels_row00[(c - 2) as usize]; + } + } + NODE_51=> { + if img_row01[(c - 1) as usize] > 0 { + if img_row11[(c + 1) as usize] > 0 { + return Some(NODE_52); + } + else { + if img_row11[(c) as usize] > 0 { + return Some(NODE_53); + } + else { + img_labels_row00[c as usize] = img_labels_row00[(c - 2) as usize]; + } + } + } + else { + return Some(NODE_54); + } + } + NODE_53=> { + if img_row12[(c - 1) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row12[(c) as usize]; + } + else { + img_labels_row00[c as usize] = LabelsSolver::merge(img_labels_row00[(c - 2) as usize], img_labels_row12[(c) as usize], solver); + } + } + NODE_55=> { + if img_row01[(c - 1) as usize] > 0 { + if img_row11[(c + 1) as usize] > 0 { + if img_row11[(c - 2) as usize] > 0 { + return Some(NODE_52); + } + else { + img_labels_row00[c as usize] = LabelsSolver::merge(img_labels_row00[(c - 2) as usize], img_labels_row12[(c) as usize], solver); + } + } + else { + if img_row11[(c) as usize] > 0 { + return Some(NODE_50); + } + else { + img_labels_row00[c as usize] = img_labels_row00[(c - 2) as usize]; + } + } + } + else { + return Some(NODE_54); + } + } + NODE_56=> { + if img_row11[(c + 1) as usize] > 0 { + if img_row12[(c) as usize] > 0 { + return Some(NODE_53); + } + else { + img_labels_row00[c as usize] = LabelsSolver::merge(img_labels_row00[(c - 2) as usize], img_labels_row12[(c) as usize], solver); + } + } + else { + img_labels_row00[c as usize] = img_labels_row00[(c - 2) as usize]; + } + } + NODE_57=> { + if img_row11[(c + 1) as usize] > 0 { + return Some(NODE_58); + } + else { + img_labels_row00[c as usize] = img_labels_row00[(c - 2) as usize]; + } + } + NODE_58=> { + if img_row11[(c) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row12[(c) as usize]; + } + else { + if img_row12[(c) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row12[(c) as usize]; + } + else { + img_labels_row00[c as usize] = LabelsSolver::merge(img_labels_row00[(c - 2) as usize], img_labels_row12[(c) as usize], solver); + } + } + } + NODE_59=> { + if img_row01[(c - 1) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row00[(c - 2) as usize]; + } + else { + img_labels_row00[c as usize] = solver.new_label(); + } + } + NODE_60=> { + if img_row00[(c + 1) as usize] > 0 { + if img_row11[(c + 1) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row12[(c) as usize]; + } + else { + img_labels_row00[c as usize] = solver.new_label(); + } + } + else { + return Some(NODE_61); + } + } + NODE_62=> { + if img_row11[(c + 1) as usize] > 0 { + img_labels_row00[c as usize] = LabelsSolver::merge(img_labels_row00[(c - 2) as usize], img_labels_row12[(c) as usize], solver); + } + else { + img_labels_row00[c as usize] = img_labels_row00[(c - 2) as usize]; + } + } + NODE_52=> { + if img_row12[(c - 1) as usize] > 0 { + return Some(NODE_58); + } + else { + img_labels_row00[c as usize] = LabelsSolver::merge(img_labels_row00[(c - 2) as usize], img_labels_row12[(c) as usize], solver); + } + } + NODE_63=> { + if img_row01[(c) as usize] > 0 { + if img_row00[(c + 1) as usize] > 0 { + return Some(NODE_64); + } + else { + return Some(NODE_59); + } + } + else { + return Some(NODE_65); + } + } + NODE_64=> { + if img_row01[(c - 1) as usize] > 0 { + return Some(NODE_66); + } + else { + return Some(NODE_54); + } + } + NODE_50=> { + if img_row11[(c - 2) as usize] > 0 { + return Some(NODE_53); + } + else { + img_labels_row00[c as usize] = LabelsSolver::merge(img_labels_row00[(c - 2) as usize], img_labels_row12[(c) as usize], solver); + } + } + NODE_61=> { + if img_row01[(c + 1) as usize] > 0 { + img_labels_row00[c as usize] = solver.new_label(); + } + else { + img_labels_row00[c as usize] = 0; + } + } + NODE_67=> { + if img_row11[(c) as usize] > 0 { + img_labels_row00[c as usize] = LabelsSolver::merge(img_labels_row00[(c - 2) as usize], img_labels_row12[(c) as usize], solver); + } + else { + img_labels_row00[c as usize] = img_labels_row00[(c - 2) as usize]; + } + } + NODE_65=> { + if img_row00[(c + 1) as usize] > 0 { + return Some(NODE_54); + } + else { + return Some(NODE_61); + } + } + NODE_68=> { + if img_row01[(c) as usize] > 0 { + if img_row00[(c + 1) as usize] > 0 { + return Some(NODE_54); + } + else { + img_labels_row00[c as usize] = solver.new_label(); + } + } + else { + return Some(NODE_65); + } + } + NODE_66=> { + if img_row11[(c + 1) as usize] > 0 { + img_labels_row00[c as usize] = LabelsSolver::merge(img_labels_row00[(c - 2) as usize], img_labels_row12[(c) as usize], solver); + } + else { + return Some(NODE_67); + } + } + NODE_69=> { + if img_row11[(c + 1) as usize] > 0 { + if img_row11[(c) as usize] > 0 { + img_labels_row00[c as usize] = LabelsSolver::merge(img_labels_row00[(c - 2) as usize], img_labels_row12[(c) as usize], solver); + } + else { + if img_row11[(c - 1) as usize] > 0 { + if img_row12[(c) as usize] > 0 { + img_labels_row00[c as usize] = LabelsSolver::merge(img_labels_row00[(c - 2) as usize], img_labels_row12[(c) as usize], solver); + } + else { + img_labels_row00[c as usize] = LabelsSolver::merge(LabelsSolver::merge(img_labels_row00[(c - 2) as usize], img_labels_row12[(c) as usize], solver), img_labels_row12[(c - 2) as usize], solver); + } + } + else { + img_labels_row00[c as usize] = LabelsSolver::merge(img_labels_row00[(c - 2) as usize], img_labels_row12[(c) as usize], solver); + } + } + } + else { + if img_row11[(c - 1) as usize] > 0 { + img_labels_row00[c as usize] = LabelsSolver::merge(img_labels_row00[(c - 2) as usize], img_labels_row12[(c - 2) as usize], solver); + } + else { + return Some(NODE_67); + } + } + } + NODE_70=> { + if img_row12[(c) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row12[(c - 2) as usize]; + } + else { + img_labels_row00[c as usize] = LabelsSolver::merge(img_labels_row12[(c) as usize], img_labels_row12[(c - 2) as usize], solver); + } + } + NODE_54=> { + if img_row11[(c + 1) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row12[(c) as usize]; + } + else { + if img_row11[(c) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row12[(c) as usize]; + } + else { + img_labels_row00[c as usize] = solver.new_label(); + } + } + } + NODE_71=> { + if img_row11[(c + 1) as usize] > 0 { + if img_row11[(c) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row12[(c) as usize]; + } + else { + if img_row11[(c - 1) as usize] > 0 { + return Some(NODE_70); + } + else { + img_labels_row00[c as usize] = img_labels_row12[(c) as usize]; + } + } + } + else { + if img_row11[(c) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row12[(c) as usize]; + } + else { + if img_row11[(c - 1) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row12[(c - 2) as usize]; + } + else { + img_labels_row00[c as usize] = solver.new_label(); + } + } + } + } +cl_break_1_0 => { + if img_row00[(c) as usize] > 0 { + return Some(NODE_54); + } + else { + return Some(NODE_68); + } + return None;} +cl_break_1_1 => { + if img_row00[(c) as usize] > 0 { + return Some(NODE_71); + } + else { + return Some(NODE_68); + } + return None;} +cl_break_1_2 => { + if img_row00[(c) as usize] > 0 { + return Some(NODE_69); + } + else { + if img_row01[(c) as usize] > 0 { + if img_row00[(c + 1) as usize] > 0 { + return Some(NODE_66); + } + else { + img_labels_row00[c as usize] = img_labels_row00[(c - 2) as usize]; + } + } + else { + return Some(NODE_65); + } + } + return None;} +cl_break_1_3 => { + if img_row00[(c) as usize] > 0 { + return Some(NODE_62); + } + else { + if img_row01[(c) as usize] > 0 { + if img_row00[(c + 1) as usize] > 0 { + return Some(NODE_62); + } + else { + img_labels_row00[c as usize] = img_labels_row00[(c - 2) as usize]; + } + } + else { + return Some(NODE_60); + } + } + return None;} +cl_break_1_4 => { + if img_row00[(c) as usize] > 0 { + return Some(NODE_56); + } + else { + if img_row01[(c) as usize] > 0 { + if img_row00[(c + 1) as usize] > 0 { + return Some(NODE_56); + } + else { + img_labels_row00[c as usize] = img_labels_row00[(c - 2) as usize]; + } + } + else { + return Some(NODE_60); + } + } + return None;} +cl_break_1_5 => { + if img_row00[(c) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row00[(c - 2) as usize]; + } + else { + if img_row01[(c) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row00[(c - 2) as usize]; + } + else { + if img_row00[(c + 1) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row12[(c) as usize]; + } + else { + return Some(NODE_61); + } + } + } + return None;} +cl_break_1_6 => { + if img_row00[(c) as usize] > 0 { + return Some(NODE_57); + } + else { + if img_row01[(c) as usize] > 0 { + if img_row00[(c + 1) as usize] > 0 { + return Some(NODE_57); + } + else { + img_labels_row00[c as usize] = img_labels_row00[(c - 2) as usize]; + } + } + else { + return Some(NODE_65); + } + } + return None;} +cl_break_1_7 => { + if img_row00[(c) as usize] > 0 { + if img_row01[(c - 1) as usize] > 0 { + return Some(NODE_69); + } + else { + return Some(NODE_71); + } + } + else { + return Some(NODE_63); + } + return None;} +cl_break_1_8 => { + if img_row00[(c) as usize] > 0 { + return Some(NODE_49); + } + else { + if img_row01[(c) as usize] > 0 { + if img_row00[(c + 1) as usize] > 0 { + return Some(NODE_49); + } + else { + img_labels_row00[c as usize] = img_labels_row00[(c - 2) as usize]; + } + } + else { + return Some(NODE_60); + } + } + return None;} +cl_break_1_9 => { + if img_row00[(c) as usize] > 0 { + return Some(NODE_64); + } + else { + return Some(NODE_63); + } + return None;} +cl_break_1_10 => { + if img_row00[(c) as usize] > 0 { + return Some(NODE_51); + } + else { + if img_row01[(c) as usize] > 0 { + if img_row00[(c + 1) as usize] > 0 { + return Some(NODE_51); + } + else { + return Some(NODE_59); + } + } + else { + return Some(NODE_65); + } + } + return None;} +cl_break_1_11 => { + if img_row00[(c) as usize] > 0 { + if img_row00[(c - 1) as usize] > 0 { + return Some(NODE_57); + } + else { + if img_row01[(c - 1) as usize] > 0 { + return Some(NODE_57); + } + else { + if img_row11[(c + 1) as usize] > 0 { + if img_row11[(c) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row12[(c) as usize]; + } + else { + return Some(NODE_70); + } + } + else { + if img_row11[(c) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row12[(c) as usize]; + } + else { + img_labels_row00[c as usize] = img_labels_row12[(c - 2) as usize]; + } + } + } + } + } + else { + if img_row01[(c) as usize] > 0 { + if img_row00[(c + 1) as usize] > 0 { + if img_row00[(c - 1) as usize] > 0 { + return Some(NODE_57); + } + else { + if img_row01[(c - 1) as usize] > 0 { + return Some(NODE_57); + } + else { + return Some(NODE_54); + } + } + } + else { + if img_row01[(c - 1) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row00[(c - 2) as usize]; + } + else { + if img_row00[(c - 1) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row00[(c - 2) as usize]; + } + else { + img_labels_row00[c as usize] = solver.new_label(); + } + } + } + } + else { + return Some(NODE_65); + } + } + return None;} +cl_break_1_12 => { + if img_row00[(c) as usize] > 0 { + return Some(NODE_55); + } + else { + if img_row01[(c) as usize] > 0 { + if img_row00[(c + 1) as usize] > 0 { + return Some(NODE_55); + } + else { + return Some(NODE_59); + } + } + else { + return Some(NODE_65); + } + } + return None;} + }; None})(label) +{ +label = next; +} +}} diff --git a/crates/burn-vision/src/backends/cpu/connected_components/spaghetti/Spaghetti_first_line_forest_code.rs b/crates/burn-vision/src/backends/cpu/connected_components/spaghetti/Spaghetti_first_line_forest_code.rs new file mode 100644 index 0000000000..4cc475d836 --- /dev/null +++ b/crates/burn-vision/src/backends/cpu/connected_components/spaghetti/Spaghetti_first_line_forest_code.rs @@ -0,0 +1,223 @@ +no_analyze!{{ +use firstLabels::*;let mut label = entry; +while let Some(next) = (|label| -> Option { match label { + NODE_72=> { + if img_row00[(c + 1) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row00[(c - 2) as usize]; + return Some(fl_tree_1); + } + else { + img_labels_row00[c as usize] = img_labels_row00[(c - 2) as usize]; + return Some(fl_tree_2); + } + } + NODE_73=> { + if img_row00[(c + 1) as usize] > 0 { + img_labels_row00[c as usize] = solver.new_label(); + return Some(fl_tree_1); + } + else { + img_labels_row00[c as usize] = solver.new_label(); + return Some(fl_tree_2); + } + } + NODE_74=> { + if img_row00[(c + 1) as usize] > 0 { + img_labels_row00[c as usize] = solver.new_label(); + return Some(fl_tree_1); + } + else { + if img_row01[(c + 1) as usize] > 0 { + img_labels_row00[c as usize] = solver.new_label(); + return Some(fl_tree_1); + } + else { + img_labels_row00[c as usize] = 0; + return Some(fl_tree_0); + } + } + } +fl_tree_0 => { +if ({c+=2; c}) >= w - 2 { if c > w - 2 { return Some(fl_break_0_0); } else { return Some(fl_break_1_0); } } + if img_row00[(c) as usize] > 0 { + return Some(NODE_73); + } + else { + if img_row01[(c) as usize] > 0 { + return Some(NODE_73); + } + else { + return Some(NODE_74); + } + } +} +fl_tree_1 => { +if ({c+=2; c}) >= w - 2 { if c > w - 2 { return Some(fl_break_0_1); } else { return Some(fl_break_1_1); } } + if img_row00[(c) as usize] > 0 { + return Some(NODE_72); + } + else { + if img_row01[(c) as usize] > 0 { + return Some(NODE_72); + } + else { + return Some(NODE_74); + } + } +} +fl_tree_2 => { +if ({c+=2; c}) >= w - 2 { if c > w - 2 { return Some(fl_break_0_2); } else { return Some(fl_break_1_2); } } + if img_row00[(c) as usize] > 0 { + if img_row01[(c - 1) as usize] > 0 { + return Some(NODE_72); + } + else { + return Some(NODE_73); + } + } + else { + if img_row01[(c) as usize] > 0 { + if img_row00[(c + 1) as usize] > 0 { + if img_row01[(c - 1) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row00[(c - 2) as usize]; + return Some(fl_tree_1); + } + else { + img_labels_row00[c as usize] = solver.new_label(); + return Some(fl_tree_1); + } + } + else { + if img_row01[(c - 1) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row00[(c - 2) as usize]; + return Some(fl_tree_2); + } + else { + img_labels_row00[c as usize] = solver.new_label(); + return Some(fl_tree_2); + } + } + } + else { + return Some(NODE_74); + } + } +} + NODE_75=> { + if img_row01[(c - 1) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row00[(c - 2) as usize]; + } + else { + img_labels_row00[c as usize] = solver.new_label(); + } + } +fl_break_0_0 => { + if img_row00[(c) as usize] > 0 { + img_labels_row00[c as usize] = solver.new_label(); + } + else { + if img_row01[(c) as usize] > 0 { + img_labels_row00[c as usize] = solver.new_label(); + } + else { + img_labels_row00[c as usize] = 0; + } + } + return None;} +fl_break_0_1 => { + if img_row00[(c) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row00[(c - 2) as usize]; + } + else { + if img_row01[(c) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row00[(c - 2) as usize]; + } + else { + img_labels_row00[c as usize] = 0; + } + } + return None;} +fl_break_0_2 => { + if img_row00[(c) as usize] > 0 { + return Some(NODE_75); + } + else { + if img_row01[(c) as usize] > 0 { + return Some(NODE_75); + } + else { + img_labels_row00[c as usize] = 0; + } + } + return None;} + NODE_76=> { + if img_row00[(c + 1) as usize] > 0 { + img_labels_row00[c as usize] = solver.new_label(); + } + else { + if img_row01[(c + 1) as usize] > 0 { + img_labels_row00[c as usize] = solver.new_label(); + } + else { + img_labels_row00[c as usize] = 0; + } + } + } + NODE_77=> { + if img_row01[(c - 1) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row00[(c - 2) as usize]; + } + else { + img_labels_row00[c as usize] = solver.new_label(); + } + } +fl_break_1_0 => { + if img_row00[(c) as usize] > 0 { + img_labels_row00[c as usize] = solver.new_label(); + } + else { + if img_row01[(c) as usize] > 0 { + img_labels_row00[c as usize] = solver.new_label(); + } + else { + return Some(NODE_76); + } + } + return None;} +fl_break_1_1 => { + if img_row00[(c) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row00[(c - 2) as usize]; + } + else { + if img_row01[(c) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row00[(c - 2) as usize]; + } + else { + return Some(NODE_76); + } + } + return None;} +fl_break_1_2 => { + if img_row00[(c) as usize] > 0 { + return Some(NODE_77); + } + else { + if img_row01[(c) as usize] > 0 { + if img_row00[(c + 1) as usize] > 0 { + return Some(NODE_77); + } + else { + return Some(NODE_77); + } + } + else { + return Some(NODE_76); + } + } + return None;} +fl_ => {}, + }; None})(label) +{ +label = next; +} +}} diff --git a/crates/burn-vision/src/backends/cpu/connected_components/spaghetti/Spaghetti_forest_labels.rs b/crates/burn-vision/src/backends/cpu/connected_components/spaghetti/Spaghetti_forest_labels.rs new file mode 100644 index 0000000000..6c994fc7d3 --- /dev/null +++ b/crates/burn-vision/src/backends/cpu/connected_components/spaghetti/Spaghetti_forest_labels.rs @@ -0,0 +1,191 @@ +/// Workaround for rust-analyzer bug that causes invalid errors on the `include!`. +macro_rules! no_analyze { + ($tokens:tt) => { + $tokens + }; +} + +pub(crate) use no_analyze; + +#[allow(non_snake_case, non_camel_case_types, unused)] +pub enum centerLabels { + NODE_1, + NODE_2, + NODE_3, + NODE_4, + NODE_5, + NODE_6, + NODE_7, + NODE_8, + NODE_9, + NODE_10, + NODE_11, + NODE_12, + NODE_13, + NODE_14, + NODE_15, + NODE_16, + NODE_17, + NODE_18, + NODE_19, + NODE_20, + NODE_21, + NODE_22, + NODE_23, + NODE_24, + NODE_25, + NODE_26, + NODE_27, + NODE_28, + NODE_29, + NODE_30, + NODE_31, + NODE_32, + NODE_33, + NODE_34, + NODE_35, + NODE_36, + NODE_37, + NODE_38, + NODE_39, + NODE_40, + NODE_41, + NODE_42, + NODE_43, + NODE_44, + NODE_45, + NODE_46, + NODE_47, + NODE_48, + NODE_49, + NODE_50, + NODE_51, + NODE_52, + NODE_53, + NODE_54, + NODE_55, + NODE_56, + NODE_57, + NODE_58, + NODE_59, + NODE_60, + NODE_61, + NODE_62, + NODE_63, + NODE_64, + NODE_65, + NODE_66, + NODE_67, + NODE_68, + NODE_69, + NODE_70, + NODE_71, + cl_tree_0, + cl_tree_1, + cl_tree_2, + cl_tree_3, + cl_tree_4, + cl_tree_5, + cl_tree_6, + cl_tree_7, + cl_tree_8, + cl_tree_9, + cl_tree_10, + cl_tree_11, + cl_tree_12, + cl_break_0_0, + cl_break_0_1, + cl_break_0_2, + cl_break_0_3, + cl_break_0_4, + cl_break_0_5, + cl_break_0_6, + cl_break_0_7, + cl_break_0_8, + cl_break_1_0, + cl_break_1_1, + cl_break_1_2, + cl_break_1_3, + cl_break_1_4, + cl_break_1_5, + cl_break_1_6, + cl_break_1_7, + cl_break_1_8, + cl_break_1_9, + cl_break_1_10, + cl_break_1_11, + cl_break_1_12, +} + +#[allow(non_snake_case, non_camel_case_types, unused)] +pub enum firstLabels { + NODE_72, + NODE_73, + NODE_74, + NODE_75, + NODE_76, + NODE_77, + fl_tree_0, + fl_tree_1, + fl_tree_2, + fl_break_0_0, + fl_break_0_1, + fl_break_0_2, + fl_break_1_0, + fl_break_1_1, + fl_break_1_2, + fl_, +} + +#[allow(non_snake_case, non_camel_case_types, unused)] +pub enum lastLabels { + NODE_78, + NODE_79, + NODE_80, + NODE_81, + NODE_82, + NODE_83, + NODE_84, + NODE_85, + NODE_86, + NODE_87, + NODE_88, + NODE_89, + NODE_90, + NODE_91, + NODE_92, + ll_tree_0, + ll_tree_1, + ll_tree_2, + ll_tree_3, + ll_tree_4, + ll_tree_5, + ll_tree_6, + ll_tree_7, + ll_break_0_0, + ll_break_0_1, + ll_break_0_2, + ll_break_0_3, + ll_break_1_0, + ll_break_1_1, + ll_break_1_2, + ll_break_1_3, + ll_break_1_4, + ll_break_1_5, + ll_break_1_6, + ll_break_1_7, + ll_, +} + +#[allow(non_snake_case, non_camel_case_types, unused)] +pub enum singleLabels { + NODE_93, + NODE_94, + sl_tree_0, + sl_tree_1, + sl_break_0_0, + sl_break_0_1, + sl_break_1_0, + sl_break_1_1, + sl_, +} diff --git a/crates/burn-vision/src/backends/cpu/connected_components/spaghetti/Spaghetti_last_line_forest_code.rs b/crates/burn-vision/src/backends/cpu/connected_components/spaghetti/Spaghetti_last_line_forest_code.rs new file mode 100644 index 0000000000..945c40f132 --- /dev/null +++ b/crates/burn-vision/src/backends/cpu/connected_components/spaghetti/Spaghetti_last_line_forest_code.rs @@ -0,0 +1,787 @@ +no_analyze!{{ +use lastLabels::*;let mut label = entry; +while let Some(next) = (|label| -> Option { match label { + NODE_78=> { + if img_row12[(c - 1) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row12[(c) as usize]; + return Some(ll_tree_6); + } + else { + img_labels_row00[c as usize] = LabelsSolver::merge(img_labels_row00[(c - 2) as usize], img_labels_row12[(c) as usize], solver); + return Some(ll_tree_6); + } + } + NODE_79=> { + if img_row12[(c + 1) as usize] > 0 { + if img_row12[(c) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row12[(c + 2) as usize]; + return Some(ll_tree_4); + } + else { + img_labels_row00[c as usize] = LabelsSolver::merge(img_labels_row12[(c + 2) as usize], img_labels_row12[(c - 2) as usize], solver); + return Some(ll_tree_4); + } + } + else { + img_labels_row00[c as usize] = LabelsSolver::merge(img_labels_row12[(c + 2) as usize], img_labels_row12[(c - 2) as usize], solver); + return Some(ll_tree_4); + } + } + NODE_80=> { + if img_row12[(c + 1) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row12[(c + 2) as usize]; + return Some(ll_tree_4); + } + else { + img_labels_row00[c as usize] = LabelsSolver::merge(img_labels_row12[(c + 2) as usize], img_labels_row12[(c) as usize], solver); + return Some(ll_tree_4); + } + } + NODE_81=> { + if img_row00[(c + 1) as usize] > 0 { + if img_row11[(c + 1) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row12[(c) as usize]; + return Some(ll_tree_5); + } + else { + if img_row11[(c + 2) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row12[(c + 2) as usize]; + return Some(ll_tree_4); + } + else { + img_labels_row00[c as usize] = solver.new_label(); + return Some(ll_tree_2); + } + } + } + else { + img_labels_row00[c as usize] = 0; + return Some(ll_tree_1); + } + } + NODE_82=> { + if img_row00[(c + 1) as usize] > 0 { + if img_row11[(c + 1) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row12[(c) as usize]; + return Some(ll_tree_5); + } + else { + return Some(NODE_83); + } + } + else { + img_labels_row00[c as usize] = 0; + return Some(ll_tree_1); + } + } + NODE_84=> { + if img_row12[(c) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row12[(c) as usize]; + return Some(ll_tree_6); + } + else { + img_labels_row00[c as usize] = LabelsSolver::merge(img_labels_row12[(c) as usize], img_labels_row12[(c - 2) as usize], solver); + return Some(ll_tree_6); + } + } + NODE_83=> { + if img_row11[(c + 2) as usize] > 0 { + if img_row11[(c) as usize] > 0 { + return Some(NODE_80); + } + else { + img_labels_row00[c as usize] = img_labels_row12[(c + 2) as usize]; + return Some(ll_tree_4); + } + } + else { + if img_row11[(c) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row12[(c) as usize]; + return Some(ll_tree_3); + } + else { + img_labels_row00[c as usize] = solver.new_label(); + return Some(ll_tree_2); + } + } + } + NODE_85=> { + if img_row12[(c - 1) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row12[(c + 2) as usize]; + return Some(ll_tree_4); + } + else { + img_labels_row00[c as usize] = LabelsSolver::merge(img_labels_row12[(c + 2) as usize], img_labels_row12[(c - 2) as usize], solver); + return Some(ll_tree_4); + } + } + NODE_86=> { + if img_row11[(c + 1) as usize] > 0 { + if img_row11[(c) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row12[(c) as usize]; + return Some(ll_tree_6); + } + else { + if img_row12[(c) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row12[(c) as usize]; + return Some(ll_tree_6); + } + else { + img_labels_row00[c as usize] = LabelsSolver::merge(img_labels_row00[(c - 2) as usize], img_labels_row12[(c) as usize], solver); + return Some(ll_tree_6); + } + } + } + else { + if img_row00[(c + 1) as usize] > 0 { + if img_row11[(c + 2) as usize] > 0 { + if img_row12[(c + 1) as usize] > 0 { + if img_row11[(c) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row12[(c + 2) as usize]; + return Some(ll_tree_4); + } + else { + if img_row12[(c) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row12[(c + 2) as usize]; + return Some(ll_tree_4); + } + else { + img_labels_row00[c as usize] = LabelsSolver::merge(img_labels_row00[(c - 2) as usize], img_labels_row12[(c + 2) as usize], solver); + return Some(ll_tree_4); + } + } + } + else { + img_labels_row00[c as usize] = LabelsSolver::merge(img_labels_row00[(c - 2) as usize], img_labels_row12[(c + 2) as usize], solver); + return Some(ll_tree_4); + } + } + else { + img_labels_row00[c as usize] = img_labels_row00[(c - 2) as usize]; + return Some(ll_tree_7); + } + } + else { + img_labels_row00[c as usize] = img_labels_row00[(c - 2) as usize]; + return Some(ll_tree_0); + } + } + } +ll_tree_0 => { +if ({c+=2; c}) >= w - 2 { if c > w - 2 { return Some(ll_break_0_0); } else { return Some(ll_break_1_0); } } + if img_row00[(c) as usize] > 0 { + if img_row11[(c + 1) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row12[(c) as usize]; + return Some(ll_tree_6); + } + else { + if img_row00[(c + 1) as usize] > 0 { + return Some(NODE_83); + } + else { + if img_row11[(c) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row12[(c) as usize]; + return Some(ll_tree_0); + } + else { + img_labels_row00[c as usize] = solver.new_label(); + return Some(ll_tree_0); + } + } + } + } + else { + return Some(NODE_82); + } +} +ll_tree_1 => { +if ({c+=2; c}) >= w - 2 { if c > w - 2 { return Some(ll_break_0_1); } else { return Some(ll_break_1_1); } } + if img_row00[(c) as usize] > 0 { + if img_row11[(c + 1) as usize] > 0 { + if img_row11[(c) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row12[(c) as usize]; + return Some(ll_tree_6); + } + else { + if img_row11[(c - 1) as usize] > 0 { + return Some(NODE_84); + } + else { + img_labels_row00[c as usize] = img_labels_row12[(c) as usize]; + return Some(ll_tree_6); + } + } + } + else { + if img_row00[(c + 1) as usize] > 0 { + if img_row11[(c + 2) as usize] > 0 { + if img_row11[(c) as usize] > 0 { + return Some(NODE_80); + } + else { + if img_row11[(c - 1) as usize] > 0 { + return Some(NODE_79); + } + else { + img_labels_row00[c as usize] = img_labels_row12[(c + 2) as usize]; + return Some(ll_tree_4); + } + } + } + else { + if img_row11[(c) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row12[(c) as usize]; + return Some(ll_tree_3); + } + else { + if img_row11[(c - 1) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row12[(c - 2) as usize]; + return Some(ll_tree_2); + } + else { + img_labels_row00[c as usize] = solver.new_label(); + return Some(ll_tree_2); + } + } + } + } + else { + if img_row11[(c) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row12[(c) as usize]; + return Some(ll_tree_0); + } + else { + if img_row11[(c - 1) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row12[(c - 2) as usize]; + return Some(ll_tree_0); + } + else { + img_labels_row00[c as usize] = solver.new_label(); + return Some(ll_tree_0); + } + } + } + } + } + else { + return Some(NODE_82); + } +} +ll_tree_2 => { +if ({c+=2; c}) >= w - 2 { if c > w - 2 { return Some(ll_break_0_2); } else { return Some(ll_break_1_2); } } + if img_row00[(c) as usize] > 0 { + if img_row11[(c + 1) as usize] > 0 { + img_labels_row00[c as usize] = LabelsSolver::merge(img_labels_row00[(c - 2) as usize], img_labels_row12[(c) as usize], solver); + return Some(ll_tree_6); + } + else { + if img_row00[(c + 1) as usize] > 0 { + if img_row11[(c + 2) as usize] > 0 { + img_labels_row00[c as usize] = LabelsSolver::merge(img_labels_row00[(c - 2) as usize], img_labels_row12[(c + 2) as usize], solver); + return Some(ll_tree_4); + } + else { + img_labels_row00[c as usize] = img_labels_row00[(c - 2) as usize]; + return Some(ll_tree_7); + } + } + else { + img_labels_row00[c as usize] = img_labels_row00[(c - 2) as usize]; + return Some(ll_tree_0); + } + } + } + else { + return Some(NODE_81); + } +} +ll_tree_3 => { +if ({c+=2; c}) >= w - 2 { if c > w - 2 { return Some(ll_break_0_2); } else { return Some(ll_break_1_3); } } + if img_row00[(c) as usize] > 0 { + if img_row11[(c + 1) as usize] > 0 { + if img_row12[(c) as usize] > 0 { + return Some(NODE_78); + } + else { + img_labels_row00[c as usize] = LabelsSolver::merge(img_labels_row00[(c - 2) as usize], img_labels_row12[(c) as usize], solver); + return Some(ll_tree_6); + } + } + else { + if img_row00[(c + 1) as usize] > 0 { + if img_row11[(c + 2) as usize] > 0 { + if img_row12[(c + 1) as usize] > 0 { + if img_row12[(c) as usize] > 0 { + return Some(NODE_85); + } + else { + img_labels_row00[c as usize] = LabelsSolver::merge(img_labels_row00[(c - 2) as usize], img_labels_row12[(c + 2) as usize], solver); + return Some(ll_tree_4); + } + } + else { + img_labels_row00[c as usize] = LabelsSolver::merge(img_labels_row00[(c - 2) as usize], img_labels_row12[(c + 2) as usize], solver); + return Some(ll_tree_4); + } + } + else { + img_labels_row00[c as usize] = img_labels_row00[(c - 2) as usize]; + return Some(ll_tree_7); + } + } + else { + img_labels_row00[c as usize] = img_labels_row00[(c - 2) as usize]; + return Some(ll_tree_0); + } + } + } + else { + return Some(NODE_81); + } +} +ll_tree_4 => { +if ({c+=2; c}) >= w - 2 { if c > w - 2 { return Some(ll_break_0_2); } else { return Some(ll_break_1_4); } } + if img_row00[(c) as usize] > 0 { + if img_row11[(c + 1) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row12[(c) as usize]; + return Some(ll_tree_6); + } + else { + if img_row00[(c + 1) as usize] > 0 { + if img_row11[(c + 2) as usize] > 0 { + if img_row12[(c + 1) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row12[(c + 2) as usize]; + return Some(ll_tree_4); + } + else { + img_labels_row00[c as usize] = LabelsSolver::merge(img_labels_row00[(c - 2) as usize], img_labels_row12[(c + 2) as usize], solver); + return Some(ll_tree_4); + } + } + else { + img_labels_row00[c as usize] = img_labels_row00[(c - 2) as usize]; + return Some(ll_tree_7); + } + } + else { + img_labels_row00[c as usize] = img_labels_row00[(c - 2) as usize]; + return Some(ll_tree_0); + } + } + } + else { + if img_row00[(c + 1) as usize] > 0 { + if img_row11[(c + 1) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row12[(c) as usize]; + return Some(ll_tree_5); + } + else { + if img_row11[(c + 2) as usize] > 0 { + return Some(NODE_80); + } + else { + img_labels_row00[c as usize] = img_labels_row12[(c) as usize]; + return Some(ll_tree_3); + } + } + } + else { + img_labels_row00[c as usize] = 0; + return Some(ll_tree_1); + } + } +} +ll_tree_5 => { +if ({c+=2; c}) >= w - 2 { if c > w - 2 { return Some(ll_break_0_2); } else { return Some(ll_break_1_5); } } + if img_row00[(c) as usize] > 0 { + return Some(NODE_86); + } + else { + return Some(NODE_82); + } +} +ll_tree_6 => { +if ({c+=2; c}) >= w - 2 { if c > w - 2 { return Some(ll_break_0_3); } else { return Some(ll_break_1_6); } } + if img_row00[(c) as usize] > 0 { + if img_row00[(c - 1) as usize] > 0 { + return Some(NODE_86); + } + else { + if img_row11[(c + 1) as usize] > 0 { + if img_row11[(c) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row12[(c) as usize]; + return Some(ll_tree_6); + } + else { + return Some(NODE_84); + } + } + else { + if img_row00[(c + 1) as usize] > 0 { + if img_row11[(c + 2) as usize] > 0 { + if img_row11[(c) as usize] > 0 { + return Some(NODE_80); + } + else { + return Some(NODE_79); + } + } + else { + if img_row11[(c) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row12[(c) as usize]; + return Some(ll_tree_3); + } + else { + img_labels_row00[c as usize] = img_labels_row12[(c - 2) as usize]; + return Some(ll_tree_2); + } + } + } + else { + if img_row11[(c) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row12[(c) as usize]; + return Some(ll_tree_0); + } + else { + img_labels_row00[c as usize] = img_labels_row12[(c - 2) as usize]; + return Some(ll_tree_0); + } + } + } + } + } + else { + return Some(NODE_82); + } +} +ll_tree_7 => { +if ({c+=2; c}) >= w - 2 { if c > w - 2 { return Some(ll_break_0_2); } else { return Some(ll_break_1_7); } } + if img_row00[(c) as usize] > 0 { + if img_row11[(c + 1) as usize] > 0 { + if img_row12[(c) as usize] > 0 { + if img_row11[(c - 2) as usize] > 0 { + return Some(NODE_78); + } + else { + img_labels_row00[c as usize] = LabelsSolver::merge(img_labels_row00[(c - 2) as usize], img_labels_row12[(c) as usize], solver); + return Some(ll_tree_6); + } + } + else { + img_labels_row00[c as usize] = LabelsSolver::merge(img_labels_row00[(c - 2) as usize], img_labels_row12[(c) as usize], solver); + return Some(ll_tree_6); + } + } + else { + if img_row00[(c + 1) as usize] > 0 { + if img_row11[(c + 2) as usize] > 0 { + if img_row12[(c + 1) as usize] > 0 { + if img_row12[(c) as usize] > 0 { + if img_row11[(c - 2) as usize] > 0 { + return Some(NODE_85); + } + else { + img_labels_row00[c as usize] = LabelsSolver::merge(img_labels_row00[(c - 2) as usize], img_labels_row12[(c + 2) as usize], solver); + return Some(ll_tree_4); + } + } + else { + img_labels_row00[c as usize] = LabelsSolver::merge(img_labels_row00[(c - 2) as usize], img_labels_row12[(c + 2) as usize], solver); + return Some(ll_tree_4); + } + } + else { + img_labels_row00[c as usize] = LabelsSolver::merge(img_labels_row00[(c - 2) as usize], img_labels_row12[(c + 2) as usize], solver); + return Some(ll_tree_4); + } + } + else { + img_labels_row00[c as usize] = img_labels_row00[(c - 2) as usize]; + return Some(ll_tree_7); + } + } + else { + img_labels_row00[c as usize] = img_labels_row00[(c - 2) as usize]; + return Some(ll_tree_0); + } + } + } + else { + return Some(NODE_81); + } +} +ll_break_0_0 => { + if img_row00[(c) as usize] > 0 { + if img_row11[(c) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row12[(c) as usize]; + } + else { + img_labels_row00[c as usize] = solver.new_label(); + } + } + else { + img_labels_row00[c as usize] = 0; + } + return None;} +ll_break_0_1 => { + if img_row00[(c) as usize] > 0 { + if img_row11[(c) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row12[(c) as usize]; + } + else { + if img_row11[(c - 1) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row12[(c - 2) as usize]; + } + else { + img_labels_row00[c as usize] = solver.new_label(); + } + } + } + else { + img_labels_row00[c as usize] = 0; + } + return None;} +ll_break_0_2 => { + if img_row00[(c) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row00[(c - 2) as usize]; + } + else { + img_labels_row00[c as usize] = 0; + } + return None;} +ll_break_0_3 => { + if img_row00[(c) as usize] > 0 { + if img_row00[(c - 1) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row00[(c - 2) as usize]; + } + else { + if img_row11[(c) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row12[(c) as usize]; + } + else { + img_labels_row00[c as usize] = img_labels_row12[(c - 2) as usize]; + } + } + } + else { + img_labels_row00[c as usize] = 0; + } + return None;} + NODE_87=> { + if img_row00[(c + 1) as usize] > 0 { + return Some(NODE_88); + } + else { + img_labels_row00[c as usize] = 0; + } + } + NODE_88=> { + if img_row11[(c + 1) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row12[(c) as usize]; + } + else { + if img_row11[(c) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row12[(c) as usize]; + } + else { + img_labels_row00[c as usize] = solver.new_label(); + } + } + } + NODE_89=> { + if img_row00[(c + 1) as usize] > 0 { + if img_row11[(c + 1) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row12[(c) as usize]; + } + else { + img_labels_row00[c as usize] = solver.new_label(); + } + } + else { + img_labels_row00[c as usize] = 0; + } + } + NODE_90=> { + if img_row12[(c) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row12[(c - 2) as usize]; + } + else { + img_labels_row00[c as usize] = LabelsSolver::merge(img_labels_row12[(c) as usize], img_labels_row12[(c - 2) as usize], solver); + } + } + NODE_91=> { + if img_row12[(c - 1) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row00[(c - 2) as usize]; + } + else { + img_labels_row00[c as usize] = LabelsSolver::merge(img_labels_row12[(c) as usize], img_labels_row12[(c - 2) as usize], solver); + } + } + NODE_92=> { + if img_row11[(c + 1) as usize] > 0 { + if img_row11[(c) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row00[(c - 2) as usize]; + } + else { + if img_row12[(c) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row00[(c - 2) as usize]; + } + else { + img_labels_row00[c as usize] = LabelsSolver::merge(img_labels_row00[(c - 2) as usize], img_labels_row12[(c) as usize], solver); + } + } + } + else { + img_labels_row00[c as usize] = img_labels_row00[(c - 2) as usize]; + } + } +ll_break_1_0 => { + if img_row00[(c) as usize] > 0 { + return Some(NODE_88); + } + else { + return Some(NODE_87); + } + return None;} +ll_break_1_1 => { + if img_row00[(c) as usize] > 0 { + if img_row11[(c + 1) as usize] > 0 { + if img_row11[(c) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row12[(c) as usize]; + } + else { + if img_row11[(c - 1) as usize] > 0 { + return Some(NODE_90); + } + else { + img_labels_row00[c as usize] = img_labels_row12[(c) as usize]; + } + } + } + else { + if img_row11[(c) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row12[(c) as usize]; + } + else { + if img_row11[(c - 1) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row12[(c - 2) as usize]; + } + else { + img_labels_row00[c as usize] = solver.new_label(); + } + } + } + } + else { + return Some(NODE_87); + } + return None;} +ll_break_1_2 => { + if img_row00[(c) as usize] > 0 { + if img_row11[(c + 1) as usize] > 0 { + img_labels_row00[c as usize] = LabelsSolver::merge(img_labels_row00[(c - 2) as usize], img_labels_row12[(c) as usize], solver); + } + else { + img_labels_row00[c as usize] = img_labels_row00[(c - 2) as usize]; + } + } + else { + return Some(NODE_89); + } + return None;} +ll_break_1_3 => { + if img_row00[(c) as usize] > 0 { + if img_row11[(c + 1) as usize] > 0 { + if img_row12[(c) as usize] > 0 { + return Some(NODE_91); + } + else { + img_labels_row00[c as usize] = LabelsSolver::merge(img_labels_row00[(c - 2) as usize], img_labels_row12[(c) as usize], solver); + } + } + else { + img_labels_row00[c as usize] = img_labels_row00[(c - 2) as usize]; + } + } + else { + return Some(NODE_89); + } + return None;} +ll_break_1_4 => { + if img_row00[(c) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row00[(c - 2) as usize]; + } + else { + if img_row00[(c + 1) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row12[(c) as usize]; + } + else { + img_labels_row00[c as usize] = 0; + } + } + return None;} +ll_break_1_5 => { + if img_row00[(c) as usize] > 0 { + return Some(NODE_92); + } + else { + return Some(NODE_87); + } + return None;} +ll_break_1_6 => { + if img_row00[(c) as usize] > 0 { + if img_row00[(c - 1) as usize] > 0 { + return Some(NODE_92); + } + else { + if img_row11[(c + 1) as usize] > 0 { + if img_row11[(c) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row12[(c) as usize]; + } + else { + return Some(NODE_90); + } + } + else { + if img_row11[(c) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row12[(c) as usize]; + } + else { + img_labels_row00[c as usize] = img_labels_row12[(c - 2) as usize]; + } + } + } + } + else { + return Some(NODE_87); + } + return None;} +ll_break_1_7 => { + if img_row00[(c) as usize] > 0 { + if img_row11[(c + 1) as usize] > 0 { + if img_row12[(c) as usize] > 0 { + if img_row11[(c - 2) as usize] > 0 { + return Some(NODE_91); + } + else { + img_labels_row00[c as usize] = LabelsSolver::merge(img_labels_row00[(c - 2) as usize], img_labels_row12[(c) as usize], solver); + } + } + else { + img_labels_row00[c as usize] = LabelsSolver::merge(img_labels_row00[(c - 2) as usize], img_labels_row12[(c) as usize], solver); + } + } + else { + img_labels_row00[c as usize] = img_labels_row00[(c - 2) as usize]; + } + } + else { + return Some(NODE_89); + } + return None;} +ll_ => {}, + }; None})(label) +{ +label = next; +} +}} diff --git a/crates/burn-vision/src/backends/cpu/connected_components/spaghetti/Spaghetti_single_line_forest_code.rs b/crates/burn-vision/src/backends/cpu/connected_components/spaghetti/Spaghetti_single_line_forest_code.rs new file mode 100644 index 0000000000..e818d77e16 --- /dev/null +++ b/crates/burn-vision/src/backends/cpu/connected_components/spaghetti/Spaghetti_single_line_forest_code.rs @@ -0,0 +1,91 @@ +no_analyze!{{ +use singleLabels::*;let mut label = entry; +while let Some(next) = (|label| -> Option { match label { + NODE_93=> { + if img_row00[(c + 1) as usize] > 0 { + img_labels_row00[c as usize] = solver.new_label(); + return Some(sl_tree_1); + } + else { + img_labels_row00[c as usize] = 0; + return Some(sl_tree_0); + } + } +sl_tree_0 => { +if ({c+=2; c}) >= w - 2 { if c > w - 2 { return Some(sl_break_0_0); } else { return Some(sl_break_1_0); } } + if img_row00[(c) as usize] > 0 { + if img_row00[(c + 1) as usize] > 0 { + img_labels_row00[c as usize] = solver.new_label(); + return Some(sl_tree_1); + } + else { + img_labels_row00[c as usize] = solver.new_label(); + return Some(sl_tree_0); + } + } + else { + return Some(NODE_93); + } +} +sl_tree_1 => { +if ({c+=2; c}) >= w - 2 { if c > w - 2 { return Some(sl_break_0_1); } else { return Some(sl_break_1_1); } } + if img_row00[(c) as usize] > 0 { + if img_row00[(c + 1) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row00[(c - 2) as usize]; + return Some(sl_tree_1); + } + else { + img_labels_row00[c as usize] = img_labels_row00[(c - 2) as usize]; + return Some(sl_tree_0); + } + } + else { + return Some(NODE_93); + } +} +sl_break_0_0 => { + if img_row00[(c) as usize] > 0 { + img_labels_row00[c as usize] = solver.new_label(); + } + else { + img_labels_row00[c as usize] = 0; + } + return None;} +sl_break_0_1 => { + if img_row00[(c) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row00[(c - 2) as usize]; + } + else { + img_labels_row00[c as usize] = 0; + } + return None;} + NODE_94=> { + if img_row00[(c + 1) as usize] > 0 { + img_labels_row00[c as usize] = solver.new_label(); + } + else { + img_labels_row00[c as usize] = 0; + } + } +sl_break_1_0 => { + if img_row00[(c) as usize] > 0 { + img_labels_row00[c as usize] = solver.new_label(); + } + else { + return Some(NODE_94); + } + return None;} +sl_break_1_1 => { + if img_row00[(c) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row00[(c - 2) as usize]; + } + else { + return Some(NODE_94); + } + return None;} +sl_ => {}, + }; None})(label) +{ +label = next; +} +}} diff --git a/crates/burn-vision/src/backends/cpu/connected_components/spaghetti/mod.rs b/crates/burn-vision/src/backends/cpu/connected_components/spaghetti/mod.rs new file mode 100644 index 0000000000..66c79ae99c --- /dev/null +++ b/crates/burn-vision/src/backends/cpu/connected_components/spaghetti/mod.rs @@ -0,0 +1,245 @@ +//! Spaghetti algorithm for connected component labeling +//! F. Bolelli, S. Allegretti, L. Baraldi, and C. Grana, +//! "Spaghetti Labeling: Directed Acyclic Graphs for Block-Based Bonnected Components Labeling," +//! IEEE Transactions on Image Processing, vol. 29, no. 1, pp. 1999-2012, 2019. +//! +//! Decision forests are generated using a modified [GRAPHGEN](https://github.com/wingertge/GRAPHGEN) +//! as described in +//! +//! F. Bolelli, S. Allegretti, C. Grana. +//! "One DAG to Rule Them All." +//! IEEE Transactions on Pattern Analisys and Machine Intelligence, 2021 + +#![allow( + unreachable_code, + clippy::collapsible_else_if, + clippy::if_same_then_else +)] + +use ndarray::{s, Array2, ArrayView2, Axis}; + +#[allow(non_snake_case)] +mod Spaghetti_forest_labels; +pub(crate) use Spaghetti_forest_labels::*; + +use crate::Connectivity; + +use super::{max_labels, Solver, StatsOp}; + +pub fn process(img: ArrayView2, stats: &mut impl StatsOp) -> Array2 { + let (h, w) = img.dim(); + + let e_rows = h as u32 & 0xfffffffe; + let o_rows = h % 2 == 1; + let e_cols = w as u32 & 0xfffffffe; + let o_cols = w % 2 == 1; + + let mut img_labels = Array2::default(img.raw_dim()); + + let mut solver = LabelsSolver::init(max_labels(h, w, Connectivity::Eight)); + + let solver = &mut solver; + + let w = w as i32; + + if h == 1 { + // Single line + let r = 0; + // Row pointers for the input image + let img_row00 = img.index_axis(Axis(0), r); + + // Row pointers for the output image + let mut img_labels_row00 = img_labels.slice_mut(s![r, ..]); + let mut c = -2i32; + let entry = singleLabels::sl_tree_0; + + include!("Spaghetti_single_line_forest_code.rs"); + } else { + // More than one line + + // First couple of lines + { + let img_row00 = img.index_axis(Axis(0), 0); + let img_row01 = img.index_axis(Axis(0), 1); + let mut img_labels_row00 = img_labels.slice_mut(s![0, ..]); + let mut c = -2i32; + let entry = firstLabels::fl_tree_0; + + include!("Spaghetti_first_line_forest_code.rs"); + } + + // Every other line but the last one if image has an odd number of rows + for r in (2..e_rows as usize).step_by(2) { + // Row pointers for the input image + let img_row00 = img.index_axis(Axis(0), r); + let img_row12 = img.index_axis(Axis(0), r - 2); + let img_row11 = img.index_axis(Axis(0), r - 1); + let img_row01 = img.index_axis(Axis(0), r + 1); + + // Row pointers for the output image + let (mut img_labels_row00, img_labels_row12) = + img_labels.multi_slice_mut((s![r, ..], s![r - 2, ..])); + + let mut c = -2; + let entry = centerLabels::cl_tree_0; + + include!("Spaghetti_center_line_forest_code.rs"); + } + + if o_rows { + let r = h - 1; + // Row pointers for the input image + let img_row00 = img.index_axis(Axis(0), r); + let img_row12 = img.index_axis(Axis(0), r - 2); + let img_row11 = img.index_axis(Axis(0), r - 1); + + // Row pointers for the output image + let (mut img_labels_row00, img_labels_row12) = + img_labels.multi_slice_mut((s![r, ..], s![r - 2, ..])); + let mut c = -2; + let entry = lastLabels::ll_tree_0; + + include!("Spaghetti_last_line_forest_code.rs"); + } + } + + let n_labels = solver.flatten(); + stats.init(n_labels); + + for r in (0..e_rows as usize).step_by(2) { + //Pointers: + // Row pointers for the input image + let img_row00 = img.index_axis(Axis(0), r); + let img_row01 = img.index_axis(Axis(0), r + 1); + + // Row pointers for the output image + let (mut img_labels_row00, mut img_labels_row01) = + img_labels.multi_slice_mut((s![r, ..], s![r + 1, ..])); + + for c in (0..e_cols as usize).step_by(2) { + let mut i_label = img_labels_row00[c]; + if i_label > 0 { + i_label = solver.get_label(i_label); + if img_row00[c] > 0 { + img_labels_row00[c] = i_label; + stats.update(r, c, i_label); + } else { + img_labels_row00[c] = 0; + stats.update(r, c, 0); + } + if img_row00[c + 1] > 0 { + img_labels_row00[c + 1] = i_label; + stats.update(r, c + 1, i_label); + } else { + img_labels_row00[c + 1] = 0; + stats.update(r, c + 1, 0); + } + if img_row01[c] > 0 { + img_labels_row01[c] = i_label; + stats.update(r + 1, c, i_label); + } else { + img_labels_row01[c] = 0; + stats.update(r + 1, c, 0); + } + if img_row01[c + 1] > 0 { + img_labels_row01[c + 1] = i_label; + stats.update(r + 1, c + 1, i_label); + } else { + img_labels_row01[c + 1] = 0; + stats.update(r + 1, c + 1, 0); + } + } else { + img_labels_row00[c] = 0; + stats.update(r, c, 0); + img_labels_row00[c + 1] = 0; + stats.update(r, c + 1, 0); + img_labels_row01[c] = 0; + stats.update(r + 1, c, 0); + img_labels_row01[c + 1] = 0; + stats.update(r + 1, c + 1, 0); + } + } + if o_cols { + let c = e_cols as usize; + let mut i_label = img_labels_row00[c]; + if i_label > 0 { + i_label = solver.get_label(i_label); + if img_row00[c] > 0 { + img_labels_row00[c] = i_label; + stats.update(r, c, i_label); + } else { + img_labels_row00[c] = 0; + stats.update(r, c, 0); + } + if img_row01[c] > 0 { + img_labels_row01[c] = i_label; + stats.update(r + 1, c, i_label); + } else { + img_labels_row01[c] = 0; + stats.update(r + 1, c, 0); + } + } else { + img_labels_row00[c] = 0; + stats.update(r, c, 0); + img_labels_row01[c] = 0; + stats.update(r + 1, c, 0); + } + } + } + + if o_rows { + let r = e_rows as usize; + + // Row pointers for the input image + let img_row00 = img.index_axis(Axis(0), r); + + // Row pointers for the output image + let mut img_labels_row00 = img_labels.slice_mut(s![r, ..]); + + for c in (0..e_cols as usize).step_by(2) { + let mut i_label = img_labels_row00[c]; + if i_label > 0 { + i_label = solver.get_label(i_label); + if img_row00[c] > 0 { + img_labels_row00[c] = i_label; + stats.update(r, c, i_label); + } else { + img_labels_row00[c] = 0; + stats.update(r, c, 0); + } + if img_row00[c + 1] > 0 { + img_labels_row00[c + 1] = i_label; + stats.update(r, c + 1, i_label); + } else { + img_labels_row00[c + 1] = 0; + stats.update(r, c + 1, 0); + } + } else { + img_labels_row00[c] = 0; + stats.update(r, c, 0); + img_labels_row00[c + 1] = 0; + stats.update(r, c + 1, 0); + } + } + if o_cols { + let c = e_cols as usize; + let mut i_label = img_labels_row00[c]; + if i_label > 0 { + i_label = solver.get_label(i_label); + if img_row00[c] > 0 { + img_labels_row00[c] = i_label; + stats.update(r, c, i_label); + } else { + img_labels_row00[c] = 0; + stats.update(r, c, 0); + } + } else { + img_labels_row00[c] = 0; + stats.update(r, c, i_label); + } + } + } + + stats.finish(); + img_labels +} diff --git a/crates/burn-vision/src/backends/cpu/connected_components/spaghetti_4c/Spaghetti4C_center_line_forest_code.rs b/crates/burn-vision/src/backends/cpu/connected_components/spaghetti_4c/Spaghetti4C_center_line_forest_code.rs new file mode 100644 index 0000000000..1c1b932334 --- /dev/null +++ b/crates/burn-vision/src/backends/cpu/connected_components/spaghetti_4c/Spaghetti4C_center_line_forest_code.rs @@ -0,0 +1,42 @@ +no_analyze!{{ +use centerLabels::*;let mut label = entry; +while let Some(next) = (|label| -> Option { match label { +cl_tree_0 => { +if ({c+=1; c} >= w) { return None; } + if img_row00[(c) as usize] > 0 { + if img_row11[(c) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row11[(c) as usize]; + return Some(cl_tree_1); + } + else { + img_labels_row00[c as usize] = solver.new_label(); + return Some(cl_tree_1); + } + } + else { + img_labels_row00[c as usize] = 0; + return Some(cl_tree_0); + } +} +cl_tree_1 => { +if ({c+=1; c} >= w) { return None; } + if img_row00[(c) as usize] > 0 { + if img_row11[(c) as usize] > 0 { + img_labels_row00[c as usize] = LabelsSolver::merge(img_labels_row00[(c - 1) as usize], img_labels_row11[(c) as usize], solver); + return Some(cl_tree_1); + } + else { + img_labels_row00[c as usize] = img_labels_row00[(c - 1) as usize]; + return Some(cl_tree_1); + } + } + else { + img_labels_row00[c as usize] = 0; + return Some(cl_tree_0); + } +} + }; None})(label) +{ +label = next; +} +}} diff --git a/crates/burn-vision/src/backends/cpu/connected_components/spaghetti_4c/Spaghetti4C_first_line_forest_code.rs b/crates/burn-vision/src/backends/cpu/connected_components/spaghetti_4c/Spaghetti4C_first_line_forest_code.rs new file mode 100644 index 0000000000..5deff2941b --- /dev/null +++ b/crates/burn-vision/src/backends/cpu/connected_components/spaghetti_4c/Spaghetti4C_first_line_forest_code.rs @@ -0,0 +1,31 @@ +no_analyze!{{ +use firstLabels::*;let mut label = entry; +while let Some(next) = (|label| -> Option { match label { +fl_tree_0 => { +if ({c+=1; c} >= w) { return None; } + if img_row00[(c) as usize] > 0 { + img_labels_row00[c as usize] = solver.new_label(); + return Some(fl_tree_1); + } + else { + img_labels_row00[c as usize] = 0; + return Some(fl_tree_0); + } +} +fl_tree_1 => { +if ({c+=1; c} >= w) { return None; } + if img_row00[(c) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row00[(c - 1) as usize]; + return Some(fl_tree_1); + } + else { + img_labels_row00[c as usize] = 0; + return Some(fl_tree_0); + } +} +fl_ => {}, + }; None})(label) +{ +label = next; +} +}} diff --git a/crates/burn-vision/src/backends/cpu/connected_components/spaghetti_4c/Spaghetti4C_forest_labels.rs b/crates/burn-vision/src/backends/cpu/connected_components/spaghetti_4c/Spaghetti4C_forest_labels.rs new file mode 100644 index 0000000000..70e89e8ab1 --- /dev/null +++ b/crates/burn-vision/src/backends/cpu/connected_components/spaghetti_4c/Spaghetti4C_forest_labels.rs @@ -0,0 +1,21 @@ +/// Workaround for rust-analyzer bug that causes invalid errors on the `include!`. +macro_rules! no_analyze { + ($tokens:tt) => { + $tokens + }; +} + +pub(crate) use no_analyze; + +#[allow(non_snake_case, non_camel_case_types, unused)] +pub enum centerLabels { + cl_tree_0, + cl_tree_1, +} + +#[allow(non_snake_case, non_camel_case_types, unused)] +pub enum firstLabels { + fl_tree_0, + fl_tree_1, + fl_, +} diff --git a/crates/burn-vision/src/backends/cpu/connected_components/spaghetti_4c/mod.rs b/crates/burn-vision/src/backends/cpu/connected_components/spaghetti_4c/mod.rs new file mode 100644 index 0000000000..d1a9ab4304 --- /dev/null +++ b/crates/burn-vision/src/backends/cpu/connected_components/spaghetti_4c/mod.rs @@ -0,0 +1,88 @@ +//! Spaghetti algorithm for connected component labeling, modified for 4-connectivity using the +//! 4-connected Rosenfeld mask. +//! F. Bolelli, S. Allegretti, L. Baraldi, and C. Grana, +//! "Spaghetti Labeling: Directed Acyclic Graphs for Block-Based Bonnected Components Labeling," +//! IEEE Transactions on Image Processing, vol. 29, no. 1, pp. 1999-2012, 2019. +//! +//! Decision forests are generated using a modified [GRAPHGEN](https://github.com/wingertge/GRAPHGEN) +//! as described in +//! +//! F. Bolelli, S. Allegretti, C. Grana. +//! "One DAG to Rule Them All." +//! IEEE Transactions on Pattern Analisys and Machine Intelligence, 2021 + +#![allow(unreachable_code)] + +use ndarray::{s, Array2, ArrayView2, Axis}; + +use crate::Connectivity; + +use super::{max_labels, Solver, StatsOp}; + +#[allow(non_snake_case)] +mod Spaghetti4C_forest_labels; +pub(crate) use Spaghetti4C_forest_labels::*; + +pub fn process(img: ArrayView2, stats: &mut impl StatsOp) -> Array2 { + let (h, w) = img.dim(); + + let mut img_labels = Array2::default(img.raw_dim()); + + // A quick and dirty upper bound for the maximum number of labels. + // Following formula comes from the fact that a 2x2 block in 4-connectivity case + // can never have more than 2 new labels and 1 label for background. + // Worst case image example pattern: + // 1 0 1 0 1... + // 0 1 0 1 0... + // 1 0 1 0 1... + // ............ + let max_labels = max_labels(h, w, Connectivity::Four); + + let mut solver = LabelsSolver::init(max_labels); + let solver = &mut solver; + + let w = w as i32; + + // First row + { + let r = 0; + //Pointers: + // Row pointers for the input image + let img_row00 = img.index_axis(Axis(0), r); + + // Row pointers for the output image + let mut img_labels_row00 = img_labels.slice_mut(s![r, ..]); + let mut c = -1i32; + + let entry = firstLabels::fl_tree_0; + + include!("Spaghetti4C_first_line_forest_code.rs"); + } + + for r in 1..h { + //Pointers: + // Row pointers for the input image + let img_row00 = img.index_axis(Axis(0), r); + let img_row11 = img.index_axis(Axis(0), r - 1); + + // Row pointers for the output image + let (mut img_labels_row00, img_labels_row11) = + img_labels.multi_slice_mut((s![r, ..], s![r - 1, ..])); + let mut c = -1i32; + + let entry = centerLabels::cl_tree_0; + + include!("Spaghetti4C_center_line_forest_code.rs"); + } + + let n_labels = solver.flatten(); + stats.init(n_labels); + + img_labels.indexed_iter_mut().for_each(|((r, c), label)| { + *label = solver.get_label(*label); + stats.update(r, c, *label); + }); + + stats.finish(); + img_labels +} diff --git a/crates/burn-vision/src/backends/cpu/mod.rs b/crates/burn-vision/src/backends/cpu/mod.rs new file mode 100644 index 0000000000..e64f7a8d75 --- /dev/null +++ b/crates/burn-vision/src/backends/cpu/mod.rs @@ -0,0 +1,4 @@ +mod connected_components; +mod ops; + +pub use connected_components::*; diff --git a/crates/burn-vision/src/backends/cpu/ops.rs b/crates/burn-vision/src/backends/cpu/ops.rs new file mode 100644 index 0000000000..46bc3a54f3 --- /dev/null +++ b/crates/burn-vision/src/backends/cpu/ops.rs @@ -0,0 +1,19 @@ +use crate::VisionOps; + +#[cfg(feature = "candle")] +use burn_candle::{Candle, FloatCandleElement, IntCandleElement}; +#[cfg(feature = "ndarray")] +use burn_ndarray::{FloatNdArrayElement, IntNdArrayElement, NdArray, QuantElement}; +#[cfg(feature = "tch")] +use burn_tch::{LibTorch, TchElement}; + +#[cfg(feature = "ndarray")] +impl VisionOps + for NdArray +{ +} + +#[cfg(feature = "candle")] +impl VisionOps for Candle {} +#[cfg(feature = "tch")] +impl VisionOps for LibTorch {} diff --git a/crates/burn-vision/src/backends/jit/connected_components/hardware_accelerated.rs b/crates/burn-vision/src/backends/jit/connected_components/hardware_accelerated.rs new file mode 100644 index 0000000000..e4f89d25cd --- /dev/null +++ b/crates/burn-vision/src/backends/jit/connected_components/hardware_accelerated.rs @@ -0,0 +1,622 @@ +//! Hardware Accelerated 4-connected, adapted from +//! A. Hennequin, L. Lacassagne, L. Cabaret, Q. Meunier, +//! "A new Direct Connected Component Labeling and Analysis Algorithms for GPUs", +//! DASIP, 2018 + +use crate::{ + backends::jit::connected_components::stats_from_opts, ConnectedStatsOptions, + ConnectedStatsPrimitive, Connectivity, +}; +use burn_jit::{ + kernel, + ops::{into_data_sync, numeric::zeros_device}, + tensor::JitTensor, + BoolElement, FloatElement, IntElement, JitBackend, JitRuntime, +}; +use burn_tensor::{ops::IntTensorOps, Shape}; +use cubecl::{prelude::*, Feature}; + +use super::prefix_sum::prefix_sum; + +const BLOCK_H: u32 = 4; + +#[cube] +fn merge(labels: &Tensor>, label_1: u32, label_2: u32) { + let mut label_1 = label_1; + let mut label_2 = label_2; + + while label_1 != label_2 && (label_1 != u32::cast_from(Atomic::load(&labels[label_1])) - 1) { + label_1 = u32::cast_from(Atomic::load(&labels[label_1])) - 1; + } + while label_1 != label_2 && (label_2 != u32::cast_from(Atomic::load(&labels[label_2])) - 1) { + label_2 = u32::cast_from(Atomic::load(&labels[label_2])) - 1; + } + while label_1 != label_2 { + #[allow(clippy::manual_swap)] + if label_1 < label_2 { + let tmp = label_1; + label_1 = label_2; + label_2 = tmp; + } + let label_3 = u32::cast_from(Atomic::min(&labels[label_1], I::cast_from(label_2 + 1))) - 1; + if label_1 == label_3 { + label_1 = label_2; + } else { + label_1 = label_3; + } + } +} + +#[cube] +fn start_distance(pixels: u32, tx: u32) -> u32 { + u32::leading_zeros(u32::bitwise_not(pixels << (32 - tx))) +} + +#[cube] +fn end_distance(pixels: u32, tx: u32) -> u32 { + u32::find_first_set(u32::bitwise_not(pixels >> (tx + 1))) +} + +#[cube] +#[expect(unconditional_panic, reason = "clippy thinks PLANE_DIM is always 2")] +fn ballot_dyn(y: u32, pred: bool) -> u32 { + let index = y % (PLANE_DIM / 32); + plane_ballot(pred)[index] +} + +#[cube(launch_unchecked)] +fn strip_labeling( + img: &Tensor, + labels: &Tensor>, + #[comptime] connectivity: Connectivity, +) { + let mut shared_pixels = SharedMemory::::new(BLOCK_H); + + let batch = ABSOLUTE_POS_Z; + let y = ABSOLUTE_POS_Y; + let rows = labels.shape(1); + let cols = labels.shape(2); + + if y >= rows { + terminate!(); + } + + let img_stride = img.stride(1); + let labels_stride = labels.stride(1); + + let img_line_base = batch * img.stride(0) + y * img_stride + UNIT_POS_X; + let labels_line_base = batch * labels.stride(0) + y * labels.stride(1) + UNIT_POS_X; + + let mut distance_y = 0; + let mut distance_y_1 = 0; + + for i in range_stepped(0, img.shape(2), PLANE_DIM) { + let x = UNIT_POS_X + i; + + if x < cols { + let mut mask = 0xffffffffu32; + let involved_cols = cols - i; + if involved_cols < 32 { + mask >>= 32 - involved_cols; + } + + let img_index = img_line_base + i; + let labels_index = labels_line_base + i; + + let p_y = bool::cast_from(img[img_index]); + + let pixels_y = ballot_dyn(UNIT_POS_Y, p_y) & mask; + let mut s_dist_y = start_distance(pixels_y, UNIT_POS_X); + + if p_y && s_dist_y == 0 { + Atomic::store( + &labels[labels_index], + I::cast_from(labels_index - select(UNIT_POS_X == 0, distance_y, 0) + 1), + ); + } + + // Only needed pre-Volta, but we can't check that at present + sync_units(); + + if UNIT_POS_X == 0 { + shared_pixels[UNIT_POS_Y] = pixels_y; + } + + sync_units(); + + // Requires if and not select, because `select` may execute the then branch even if the + // condition is false (on non-CUDA backends), which can lead to OOB reads. + let pixels_y_1 = if UNIT_POS_Y > 0 { + shared_pixels[UNIT_POS_Y - 1] + } else { + 0u32 + }; + + let p_y_1 = (pixels_y_1 >> UNIT_POS_X) & 1 != 0; + let mut s_dist_y_1 = start_distance(pixels_y_1, UNIT_POS_X); + + if UNIT_POS_X == 0 { + s_dist_y = distance_y; + s_dist_y_1 = distance_y_1; + } + + match connectivity { + Connectivity::Four => { + if p_y && p_y_1 && (s_dist_y == 0 || s_dist_y_1 == 0) { + let label_1 = labels_index - s_dist_y; + let label_2 = labels_index - s_dist_y_1 - labels_stride; + merge(labels, label_1, label_2); + } + } + Connectivity::Eight => { + let pixels_y_shifted = (pixels_y << 1) | (distance_y > 0) as u32; + let pixels_y_1_shifted = (pixels_y_1 << 1) | (distance_y_1 > 0) as u32; + + if p_y && p_y_1 && (s_dist_y == 0 || s_dist_y_1 == 0) { + let label_1 = labels_index - s_dist_y; + let label_2 = labels_index - s_dist_y_1 - labels_stride; + merge(labels, label_1, label_2); + } else if p_y && s_dist_y == 0 && (pixels_y_1_shifted >> UNIT_POS_X) & 1 != 0 { + let s_dist_y_1_prev = select( + UNIT_POS_X == 0, + distance_y_1 - 1, + start_distance(pixels_y_1, UNIT_POS_X - 1), + ); + let label_1 = labels_index; + let label_2 = labels_index - labels_stride - 1 - s_dist_y_1_prev; + merge(labels, label_1, label_2); + } else if p_y_1 && s_dist_y_1 == 0 && (pixels_y_shifted >> UNIT_POS_X) & 1 != 0 + { + let s_dist_y_prev = select( + UNIT_POS_X == 0, + distance_y - 1, + start_distance(pixels_y, UNIT_POS_X - 1), + ); + let label_1 = labels_index - 1 - s_dist_y_prev; + let label_2 = labels_index - labels_stride; + merge(labels, label_1, label_2); + } + } + } + + if p_y && p_y_1 && (s_dist_y == 0 || s_dist_y_1 == 0) { + let label_1 = labels_index - s_dist_y; + let label_2 = labels_index - s_dist_y_1 - labels_stride; + merge(labels, label_1, label_2); + } + + let mut d = start_distance(pixels_y_1, 32); + distance_y_1 = d + select(d == 32, distance_y_1, 0); + d = start_distance(pixels_y, 32); + distance_y = d + select(d == 32, distance_y, 0); + } + } +} + +#[cube(launch_unchecked)] +fn strip_merge( + img: &Tensor, + labels: &Tensor>, + #[comptime] connectivity: Connectivity, +) { + let batch = CUBE_POS_Z; + let plane_start_x = CUBE_POS_X * (CUBE_DIM_X * CUBE_DIM_Z - PLANE_DIM) + UNIT_POS_Z * PLANE_DIM; + let y = (CUBE_POS_Y + 1) * BLOCK_H; + let x = plane_start_x + UNIT_POS_X; + + let img_step = img.stride(1); + let labels_step = labels.stride(1); + let cols = img.shape(2); + + if y < labels.shape(1) && x < labels.shape(2) { + let mut mask = 0xffffffffu32; + if cols - plane_start_x < 32 { + mask >>= 32 - (cols - plane_start_x); + } + + let img_index = batch * img.stride(0) + y * img_step + x; + let labels_index = batch * labels.stride(0) + y * labels_step + x; + + let img_index_up = img_index - img_step; + let labels_index_up = labels_index - labels_step; + + let p = bool::cast_from(img[img_index]); + let p_up = bool::cast_from(img[img_index_up]); + + let pixels = ballot_dyn(UNIT_POS_Z, p) & mask; + let pixels_up = ballot_dyn(UNIT_POS_Z, p_up) & mask; + + match connectivity { + Connectivity::Four => { + if p && p_up { + let s_dist = start_distance(pixels, UNIT_POS_X); + let s_dist_up = start_distance(pixels_up, UNIT_POS_X); + if s_dist == 0 || s_dist_up == 0 { + merge(labels, labels_index - s_dist, labels_index_up - s_dist_up); + } + } + } + Connectivity::Eight => { + let mut last_dist_vec = SharedMemory::::new(32); + let mut last_dist_up_vec = SharedMemory::::new(32); + + let s_dist = start_distance(pixels, UNIT_POS_X); + let s_dist_up = start_distance(pixels_up, UNIT_POS_X); + + if UNIT_POS_PLANE == PLANE_DIM - 1 { + last_dist_vec[UNIT_POS_Z] = start_distance(pixels, 32); + last_dist_up_vec[UNIT_POS_Z] = start_distance(pixels_up, 32); + } + + sync_units(); + + if CUBE_POS_X == 0 || UNIT_POS_Z > 0 { + let last_dist = if UNIT_POS_Z > 0 { + last_dist_vec[UNIT_POS_Z - 1] + } else { + 0u32 + }; + let last_dist_up = if UNIT_POS_Z > 0 { + last_dist_up_vec[UNIT_POS_Z - 1] + } else { + 0u32 + }; + + let p_prev = + select(UNIT_POS_X > 0, (pixels >> (UNIT_POS_X - 1)) & 1, last_dist) != 0; + let p_up_prev = select( + UNIT_POS_X > 0, + (pixels_up >> (UNIT_POS_X - 1)) & 1, + last_dist_up, + ) != 0; + + if p && p_up { + let s_dist = start_distance(pixels, UNIT_POS_X); + let s_dist_up = start_distance(pixels_up, UNIT_POS_X); + if s_dist == 0 || s_dist_up == 0 { + merge(labels, labels_index - s_dist, labels_index_up - s_dist_up); + } + } else if p && p_up_prev && s_dist == 0 { + let s_dist_up_prev = select( + UNIT_POS_X == 0, + last_dist_up - 1, + start_distance(pixels_up, UNIT_POS_X - 1), + ); + merge(labels, labels_index, labels_index_up - 1 - s_dist_up_prev); + } else if p_prev && p_up && s_dist_up == 0 { + let s_dist_prev = select( + UNIT_POS_X == 0, + last_dist - 1, + start_distance(pixels, UNIT_POS_X - 1), + ); + merge(labels, labels_index - 1 - s_dist_prev, labels_index_up); + } + } + } + } + } +} + +#[cube(launch_unchecked)] +fn relabeling(img: &Tensor, labels: &mut Tensor) { + let batch = ABSOLUTE_POS_Z; + let plane_start_x = CUBE_POS_X * CUBE_DIM_X; + let y = ABSOLUTE_POS_Y; + let x = plane_start_x + UNIT_POS_X; + + let cols = labels.shape(2); + let rows = labels.shape(1); + let img_step = img.stride(1); + let labels_step = labels.stride(1); + + if x < cols && y < rows { + let mut mask = 0xffffffffu32; + if cols - plane_start_x < 32 { + mask >>= 32 - (cols - plane_start_x); + } + + let img_index = batch * img.stride(0) + y * img_step + x; + let labels_index = batch * labels.stride(0) + y * labels_step + x; + + let p = bool::cast_from(img[img_index]); + let pixels = ballot_dyn(UNIT_POS_Y, p) & mask; + let s_dist = start_distance(pixels, UNIT_POS_X); + let mut label = 0u32; + + if p && s_dist == 0 { + label = u32::cast_from(labels[labels_index]) - 1; + while label != u32::cast_from(labels[label]) - 1 { + label = u32::cast_from(labels[label]) - 1; + } + } + + label = plane_broadcast(label, UNIT_POS_X - s_dist); + + if p { + labels[labels_index] = I::cast_from(label + 1); + } + } +} + +#[cube(launch_unchecked)] +fn analysis( + img: &Tensor, + labels: &mut Tensor, + area: &mut Tensor>, + top: &mut Tensor>, + left: &mut Tensor>, + right: &mut Tensor>, + bottom: &mut Tensor>, + max_label: &mut Tensor>, + #[comptime] opts: ConnectedStatsOptions, +) { + let batch = ABSOLUTE_POS_Z; + let y = ABSOLUTE_POS_Y; + let x = ABSOLUTE_POS_X; + + let cols = labels.shape(2); + let rows = labels.shape(1); + let img_step = img.stride(1); + let labels_step = labels.stride(1); + let b_offs = batch * labels.stride(0); + + if x < cols && y < rows { + let mut mask = 0xffffffffu32; + if cols - CUBE_POS_X * CUBE_DIM_X < 32 { + mask >>= 32 - (cols - CUBE_POS_X * CUBE_DIM_X); + } + + let img_index = b_offs + y * img_step + x; + let labels_index = b_offs + y * labels_step + x; + + let p = bool::cast_from(img[img_index]); + let pixels = ballot_dyn(UNIT_POS_Y, p) & mask; + let s_dist = start_distance(pixels, UNIT_POS_X); + let count = end_distance(pixels, UNIT_POS_X); + let max_x = x + count - 1; + + let mut label = 0u32; + + if p && s_dist == 0 { + label = u32::cast_from(labels[labels_index]) - 1; + while label != u32::cast_from(labels[b_offs + label]) - 1 { + label = u32::cast_from(labels[b_offs + label]) - 1; + } + label += 1; + + Atomic::add(&area[b_offs + label], I::cast_from(count)); + + if opts.bounds_enabled { + Atomic::min(&left[b_offs + label], I::cast_from(x)); + Atomic::min(&top[b_offs + label], I::cast_from(y)); + Atomic::max(&right[b_offs + label], I::cast_from(max_x)); + Atomic::max(&bottom[b_offs + label], I::cast_from(y)); + } + if comptime!(opts.max_label_enabled || opts.compact_labels) { + Atomic::max(&max_label[batch], I::cast_from(label)); + } + } + + label = plane_broadcast(label, UNIT_POS_X - s_dist); + + if p { + labels[labels_index] = I::cast_from(label); + } + } +} + +#[cube(launch_unchecked)] +fn compact_labels( + labels: &mut Tensor, + remap: &Tensor, + max_label: &Tensor>, +) { + let batch = ABSOLUTE_POS_Z; + let x = ABSOLUTE_POS_X; + let y = ABSOLUTE_POS_Y; + + let labels_pos = batch * labels.stride(0) + y * labels.stride(1) + x * labels.stride(2); + + if labels_pos >= labels.len() { + terminate!(); + } + + let label = u32::cast_from(labels[labels_pos]); + if label != 0 { + let new_label = remap[label]; + labels[labels_pos] = new_label; + Atomic::max(&max_label[batch], new_label); + } +} + +#[cube(launch_unchecked)] +fn compact_stats( + area: &Tensor, + area_new: &mut Tensor, + top: &Tensor, + top_new: &mut Tensor, + left: &Tensor, + left_new: &mut Tensor, + right: &Tensor, + right_new: &mut Tensor, + bottom: &Tensor, + bottom_new: &mut Tensor, + remap: &Tensor, +) { + let label = ABSOLUTE_POS_X; + if label >= remap.len() { + terminate!(); + } + + let area = area[label]; + if area == I::new(0) { + terminate!(); + } + let new_label = u32::cast_from(remap[label]); + + area_new[new_label] = area; + // This should be gated but there's a problem with the Eq bound only being implemented for tuples + // up to 12 elems, so I can't pass the opts. It's not unsafe, but potentially unnecessary work. + top_new[new_label] = top[label]; + left_new[new_label] = left[label]; + right_new[new_label] = right[label]; + bottom_new[new_label] = bottom[label]; +} + +#[allow(clippy::type_complexity)] +pub fn hardware_accelerated( + img: JitTensor, + stats_opt: ConnectedStatsOptions, + connectivity: Connectivity, +) -> Result< + ( + JitTensor, + ConnectedStatsPrimitive>, + ), + String, +> { + let client = img.client.clone(); + let device = img.device.clone(); + + if !client.properties().feature_enabled(Feature::Plane) { + return Err("Requires plane instructions".into()); + } + + let props = client.properties().hardware_properties(); + + if props.plane_size_min < 32 { + return Err("Requires plane size of at least 32".into()); + } + + let [batches, rows, cols] = img.shape.dims(); + + let labels = zeros_device::(client.clone(), device.clone(), img.shape.clone()); + + // Assume 32 wide warp. Currently, larger warps are handled by just exiting everything past 32. + // This isn't ideal but we require CUBE_DIM_X == warp_size, and we can't query the actual warp + // size at compile time. `REQUIRE_FULL_SUBGROUPS` or subgroup size controls are not supported + // in wgpu. + let warp_size = 32; + let cube_dim = CubeDim::new_2d(warp_size, BLOCK_H); + let cube_count = CubeCount::Static(1, (rows as u32).div_ceil(cube_dim.y), batches as u32); + + unsafe { + strip_labeling::launch_unchecked::( + &client, + cube_count, + cube_dim, + img.as_tensor_arg::(1), + labels.as_tensor_arg::(1), + connectivity, + ) + }; + + let horizontal_warps = Ord::min((cols as u32).div_ceil(warp_size), 32); + let cube_dim_merge = CubeDim::new_3d(warp_size, 1, horizontal_warps); + let cube_count = CubeCount::Static( + Ord::max((cols as u32 + warp_size * 30 - 1) / (warp_size * 31), 1), + (rows as u32 - 1) / BLOCK_H, + batches as u32, + ); + + unsafe { + strip_merge::launch_unchecked::( + &client, + cube_count, + cube_dim_merge, + img.as_tensor_arg::(1), + labels.as_tensor_arg::(1), + connectivity, + ) + }; + + let cube_count = CubeCount::Static( + (cols as u32).div_ceil(cube_dim.x), + (rows as u32).div_ceil(cube_dim.y), + batches as u32, + ); + + let mut stats = stats_from_opts(labels.clone(), stats_opt); + + if stats_opt == ConnectedStatsOptions::none() { + unsafe { + relabeling::launch_unchecked::( + &client, + cube_count, + cube_dim, + img.as_tensor_arg::(1), + labels.as_tensor_arg::(1), + ) + }; + } else { + unsafe { + analysis::launch_unchecked::( + &client, + cube_count, + cube_dim, + img.as_tensor_arg::(1), + labels.as_tensor_arg::(1), + stats.area.as_tensor_arg::(1), + stats.top.as_tensor_arg::(1), + stats.left.as_tensor_arg::(1), + stats.right.as_tensor_arg::(1), + stats.bottom.as_tensor_arg::(1), + stats.max_label.as_tensor_arg::(1), + stats_opt, + ) + }; + if stats_opt.compact_labels { + let max_label = JitBackend::::int_max(stats.max_label); + let max_label = into_data_sync::(max_label).convert::(); + let max_label = max_label.as_slice::().unwrap()[0] as usize; + let sliced = kernel::slice::( + stats.area.clone(), + &[0..batches, 0..(max_label + 1).next_multiple_of(4)], + ); + let relabel = prefix_sum::(sliced); + + let cube_dim = CubeDim::default(); + let cube_count = CubeCount::new_3d( + (cols as u32).div_ceil(cube_dim.x), + (rows as u32).div_ceil(cube_dim.y), + batches as u32, + ); + stats.max_label = + zeros_device::(client.clone(), device.clone(), Shape::new([batches])); + unsafe { + compact_labels::launch_unchecked::( + &client, + cube_count, + cube_dim, + labels.as_tensor_arg::(1), + relabel.as_tensor_arg::(1), + stats.max_label.as_tensor_arg::(1), + ) + }; + + let cube_dim = CubeDim::new_1d(256); + let cube_count = + CubeCount::new_3d((rows * cols).div_ceil(256) as u32, 1, batches as u32); + unsafe { + compact_stats::launch_unchecked::( + &client, + cube_count, + cube_dim, + stats.area.copy().as_tensor_arg::(1), + stats.area.as_tensor_arg::(1), + stats.top.copy().as_tensor_arg::(1), + stats.top.as_tensor_arg::(1), + stats.left.copy().as_tensor_arg::(1), + stats.left.as_tensor_arg::(1), + stats.right.copy().as_tensor_arg::(1), + stats.right.as_tensor_arg::(1), + stats.bottom.copy().as_tensor_arg::(1), + stats.bottom.as_tensor_arg::(1), + relabel.as_tensor_arg::(1), + ) + }; + } + } + + Ok((labels, stats)) +} diff --git a/crates/burn-vision/src/backends/jit/connected_components/mod.rs b/crates/burn-vision/src/backends/jit/connected_components/mod.rs new file mode 100644 index 0000000000..53627a077e --- /dev/null +++ b/crates/burn-vision/src/backends/jit/connected_components/mod.rs @@ -0,0 +1,51 @@ +mod hardware_accelerated; + +/// Should eventually make this a full op, but the kernel is too specialized on ints and plane ops +/// to really use it in a general case. Needs more work to use as a normal tensor method. +mod prefix_sum; + +use burn_jit::{ + ops::numeric::{full_device, zeros_device}, + tensor::JitTensor, + BoolElement, FloatElement, IntElement, JitBackend, JitRuntime, +}; +use burn_tensor::Shape; +pub use hardware_accelerated::*; + +use crate::{ConnectedStatsOptions, ConnectedStatsPrimitive}; + +pub(crate) fn stats_from_opts( + l: JitTensor, + opts: ConnectedStatsOptions, +) -> ConnectedStatsPrimitive> +where + R: JitRuntime, + F: FloatElement, + I: IntElement, + BT: BoolElement, +{ + let [batches, height, width] = l.shape.dims(); + let shape = Shape::new([batches, height * width]); + let zeros = || zeros_device::(l.client.clone(), l.device.clone(), shape.clone()); + let max = I::max_value(); + let max = || full_device::(l.client.clone(), shape.clone(), l.device.clone(), max); + let dummy = || { + JitTensor::new_contiguous( + l.client.clone(), + l.device.clone(), + shape.clone(), + l.handle.clone(), + l.dtype, + ) + }; + ConnectedStatsPrimitive { + area: (opts != ConnectedStatsOptions::none()) + .then(zeros) + .unwrap_or_else(dummy), + left: opts.bounds_enabled.then(max).unwrap_or_else(dummy), + top: opts.bounds_enabled.then(max).unwrap_or_else(dummy), + right: opts.bounds_enabled.then(zeros).unwrap_or_else(dummy), + bottom: opts.bounds_enabled.then(zeros).unwrap_or_else(dummy), + max_label: zeros_device::(l.client.clone(), l.device.clone(), Shape::new([1])), + } +} diff --git a/crates/burn-vision/src/backends/jit/connected_components/prefix_sum.rs b/crates/burn-vision/src/backends/jit/connected_components/prefix_sum.rs new file mode 100644 index 0000000000..f22910f442 --- /dev/null +++ b/crates/burn-vision/src/backends/jit/connected_components/prefix_sum.rs @@ -0,0 +1,256 @@ +use burn_tensor::Shape; +use cubecl::prelude::*; + +use burn_jit::{ + ops::{ + numeric::{empty_device, zeros_device}, + reshape, + }, + tensor::JitTensor, + IntElement, JitRuntime, +}; + +const CUBE_SIZE: u32 = 256; +const MIN_SUBGROUP_SIZE: u32 = 4; +const MAX_REDUCE_SIZE: u32 = CUBE_SIZE / MIN_SUBGROUP_SIZE; + +const PART_SIZE: u32 = 4096; + +#[cube(launch_unchecked)] +fn prefix_sum_kernel( + scan_in: &Tensor>, + scan_out: &mut Tensor>, + scan_bump: &Tensor>, + reduction: &Tensor>, + cube_count_x: u32, +) { + let mut broadcast = SharedMemory::::new(1); + let mut reduce = SharedMemory::::new(MAX_REDUCE_SIZE); + let batch = CUBE_POS_Z; + let line_spt = comptime!(PART_SIZE / CUBE_SIZE / scan_in.line_size()); + let nums_per_cube = CUBE_SIZE * line_spt; + let v_last = comptime!(scan_in.line_size() - 1); + + //acquire partition index + if UNIT_POS_X == 0 { + broadcast[0] = Atomic::add(&scan_bump[batch], I::new(1)); + } + sync_units(); + let part_id = u32::cast_from(broadcast[0]); + + let plane_id = UNIT_POS_X / PLANE_DIM; + let dev_offs = part_id * nums_per_cube; + let plane_offs = plane_id * PLANE_DIM * line_spt; + + // Exit if full plane is out of bounds + if dev_offs + plane_offs >= scan_in.shape(1) { + terminate!(); + } + + let zero = I::new(0); + + let flag_reduction = I::new(1); + let flag_inclusive = I::new(2); + let flag_mask = I::new(3); + + let red_offs = batch * reduction.stride(0); + let scan_offs = batch * scan_in.stride(0); + + let mut t_scan = Array::>::vectorized(line_spt, scan_in.line_size()); + { + let mut i = dev_offs + plane_offs + UNIT_POS_PLANE; + + if part_id < cube_count_x - 1 { + for k in 0..line_spt { + // Manually fuse not_equal and cast + let mut scan = Line::cast_from(scan_in[i + scan_offs].not_equal(Line::new(zero))); + #[unroll] + for v in 1..scan_in.line_size() { + let prev = scan[v - 1]; + scan[v] += prev; + } + t_scan[k] = scan; + i += PLANE_DIM; + } + } + + if part_id == cube_count_x - 1 { + for k in 0..line_spt { + if i < scan_in.shape(1) { + // Manually fuse not_equal and cast + let mut scan = + Line::cast_from(scan_in[i + scan_offs].not_equal(Line::new(zero))); + #[unroll] + for v in 1..scan_in.line_size() { + let prev = scan[v - 1]; + scan[v] += prev; + } + t_scan[k] = scan; + } + i += PLANE_DIM; + } + } + + let mut prev = zero; + let plane_mask = PLANE_DIM - 1; + let circular_shift = (UNIT_POS_PLANE + plane_mask) & plane_mask; + for k in 0..line_spt { + let t = plane_broadcast(plane_inclusive_sum(t_scan[k][v_last]), circular_shift); + t_scan[k] += Line::cast_from(select(UNIT_POS_PLANE != 0, t, zero) + prev); + prev += plane_broadcast(t, 0); + } + + if UNIT_POS_PLANE == 0 { + reduce[plane_id] = prev; + } + } + sync_units(); + + //Non-divergent subgroup agnostic inclusive scan across subgroup reductions + let lane_log = count_trailing_zeros(PLANE_DIM); + let spine_size = CUBE_DIM >> lane_log; + { + let mut offset_0 = 0; + let mut offset_1 = 0; + let aligned_size = + 1 << ((count_trailing_zeros(spine_size) + lane_log + 1) / lane_log * lane_log); + let mut j = PLANE_DIM; + while j <= aligned_size { + let i_0 = ((UNIT_POS_X + offset_0) << offset_1) - offset_0; + let pred_0 = i_0 < spine_size; + let t_0 = plane_inclusive_sum(select(pred_0, reduce[i_0], zero)); + if pred_0 { + reduce[i_0] = t_0; + } + sync_units(); + + if j != PLANE_DIM { + let rshift = j >> lane_log; + let i_1 = UNIT_POS_X + rshift; + if (i_1 & (j - 1)) >= rshift { + let pred_1 = i_1 < spine_size; + let t_1 = select(pred_1, reduce[((i_1 >> offset_1) << offset_1) - 1], zero); + if pred_1 && ((i_1 + 1) & (rshift - 1)) != 0 { + reduce[i_1] += t_1; + } + } + } else { + offset_0 += 1; + } + offset_1 += lane_log; + + j <<= lane_log; + } + } + sync_units(); + + //Device broadcast + if UNIT_POS_X == 0 { + Atomic::store( + &reduction[part_id + red_offs], + (reduce[spine_size - 1] << I::new(2)) + | select(part_id != 0, flag_reduction, flag_inclusive), + ) + } + + //Lookback, single thread + if part_id != 0 { + if UNIT_POS_X == 0 { + let mut lookback_id = part_id - 1; + let mut prev_reduction = zero; + loop { + let flag_payload = Atomic::load(&reduction[lookback_id + red_offs]); + if (flag_payload & flag_mask) == flag_inclusive { + prev_reduction += flag_payload >> I::new(2); + Atomic::store( + &reduction[part_id + red_offs], + ((prev_reduction + reduce[spine_size - 1]) << I::new(2)) | flag_inclusive, + ); + broadcast[0] = prev_reduction; + break; + } + + if (flag_payload & flag_mask) == flag_reduction { + prev_reduction += flag_payload >> I::new(2); + lookback_id -= 1; + } + } + } + sync_units(); + } + + { + let prev = if plane_id != 0 { + reduce[plane_id - 1] + } else { + zero + }; + let prev = Line::cast_from(broadcast[0] + prev); + let s_offset = UNIT_POS_PLANE + plane_id * PLANE_DIM * line_spt; + let dev_offset = part_id * nums_per_cube; + let mut i = s_offset + dev_offset; + + if part_id < cube_count_x - 1 { + for k in 0..line_spt { + scan_out[i + scan_offs] = t_scan[k] + prev; + i += PLANE_DIM; + } + } + + if part_id == cube_count_x - 1 { + for k in 0..line_spt { + if i < scan_out.shape(1) { + scan_out[i + scan_offs] = t_scan[k] + prev; + } + i += PLANE_DIM; + } + } + } +} + +#[cube] +fn count_trailing_zeros(num: u32) -> u32 { + u32::find_first_set(num) - 1 +} + +/// Compute the prefix sum of a tensor +pub fn prefix_sum(input: JitTensor) -> JitTensor { + let client = input.client.clone(); + let device = input.device.clone(); + let num_elems = input.shape.num_elements() as u32; + let numbers = *input.shape.dims.last().unwrap() as u32; + let batches = num_elems / numbers; + + let input = reshape(input, Shape::new([batches as usize, numbers as usize])); + let out = empty_device::(client.clone(), device.clone(), input.shape.clone()); + + let cubes = numbers.div_ceil(PART_SIZE); + let cube_dim = CubeDim::new_1d(CUBE_SIZE); + let cube_count = CubeCount::new_3d(cubes, 1, batches); + + let bump = zeros_device::( + client.clone(), + device.clone(), + Shape::new([batches as usize]), + ); + let reduction = zeros_device::( + client.clone(), + device.clone(), + Shape::new([batches as usize, cubes as usize]), + ); + + unsafe { + prefix_sum_kernel::launch_unchecked::( + &input.client, + cube_count, + cube_dim, + input.as_tensor_arg::(4), + out.as_tensor_arg::(4), + bump.as_tensor_arg::(1), + reduction.as_tensor_arg::(1), + ScalarArg::new(cubes), + ) + }; + + out +} diff --git a/crates/burn-vision/src/backends/jit/mod.rs b/crates/burn-vision/src/backends/jit/mod.rs new file mode 100644 index 0000000000..9d610df49a --- /dev/null +++ b/crates/burn-vision/src/backends/jit/mod.rs @@ -0,0 +1,2 @@ +mod connected_components; +mod ops; diff --git a/crates/burn-vision/src/backends/jit/ops.rs b/crates/burn-vision/src/backends/jit/ops.rs new file mode 100644 index 0000000000..8d75b37f8d --- /dev/null +++ b/crates/burn-vision/src/backends/jit/ops.rs @@ -0,0 +1,167 @@ +use crate::{ + backends::cpu, ConnectedStatsOptions, ConnectedStatsPrimitive, Connectivity, VisionOps, +}; +#[cfg(feature = "fusion")] +use burn_fusion::{client::FusionClient, stream::Operation, Fusion, FusionBackend, FusionRuntime}; +use burn_jit::{BoolElement, FloatElement, IntElement, JitBackend, JitRuntime}; +use burn_tensor::ops::{BoolTensor, IntTensor}; +#[cfg(feature = "fusion")] +use burn_tensor::{ + repr::{CustomOpDescription, HandleContainer, OperationDescription}, + Element, +}; + +use super::connected_components::hardware_accelerated; + +impl VisionOps for JitBackend +where + R: JitRuntime, + F: FloatElement, + I: IntElement, + BT: BoolElement, +{ + fn connected_components(img: BoolTensor, connectivity: Connectivity) -> IntTensor { + hardware_accelerated::( + img.clone(), + ConnectedStatsOptions::none(), + connectivity, + ) + .map(|it| it.0) + .unwrap_or_else(|_| cpu::connected_components::(img, connectivity)) + } + + fn connected_components_with_stats( + img: BoolTensor, + connectivity: Connectivity, + opts: ConnectedStatsOptions, + ) -> (IntTensor, ConnectedStatsPrimitive) { + hardware_accelerated::(img.clone(), opts, connectivity).unwrap_or_else(|_| { + cpu::connected_components_with_stats::(img, connectivity, opts) + }) + } +} + +#[cfg(feature = "fusion")] +impl> VisionOps for Fusion { + fn connected_components(img: BoolTensor, conn: Connectivity) -> IntTensor { + let batches = img.shape[0]; + let height = img.shape[1]; + let width = img.shape[2]; + let client = img.client.clone(); + + #[derive(derive_new::new)] + struct ConnComp { + desc: CustomOpDescription, + conn: Connectivity, + _b: core::marker::PhantomData, + } + + impl> Operation for ConnComp { + fn execute( + self: Box, + handles: &mut HandleContainer<::FusionHandle>, + ) { + let ([img], [labels]) = self.desc.consume(); + let input = handles.get_bool_tensor::(&img); + let output = B1::connected_components(input, self.conn); + + handles.register_int_tensor::(&labels.id, output); + } + } + + let stream = img.stream; + let out = client.tensor_uninitialized(vec![batches, height, width], B::IntElem::dtype()); + + let desc = CustomOpDescription::new( + "connected_components", + &[img.into_description()], + &[out.to_description_out()], + ); + client.register( + vec![stream], + OperationDescription::Custom(desc.clone()), + ConnComp::::new(desc, conn), + ); + + out + } + + fn connected_components_with_stats( + img: BoolTensor, + conn: Connectivity, + opts: ConnectedStatsOptions, + ) -> (IntTensor, ConnectedStatsPrimitive) { + let batches = img.shape[0]; + let height = img.shape[1]; + let width = img.shape[2]; + let client = img.client.clone(); + + #[derive(derive_new::new)] + struct ConnCompStats { + desc: CustomOpDescription, + conn: Connectivity, + opts: ConnectedStatsOptions, + _b: core::marker::PhantomData, + } + + impl> Operation for ConnCompStats { + fn execute( + self: Box, + handles: &mut HandleContainer<::FusionHandle>, + ) { + let ([img], [labels, area, left, top, right, bottom, max_label]) = + self.desc.consume(); + let input = handles.get_bool_tensor::(&img); + let (output, stats) = + B1::connected_components_with_stats(input, self.conn, self.opts); + + handles.register_int_tensor::(&labels.id, output); + handles.register_int_tensor::(&area.id, stats.area); + handles.register_int_tensor::(&left.id, stats.left); + handles.register_int_tensor::(&top.id, stats.top); + handles.register_int_tensor::(&right.id, stats.right); + handles.register_int_tensor::(&bottom.id, stats.bottom); + handles.register_int_tensor::(&max_label.id, stats.max_label); + } + } + + let stream = img.stream; + let out = client.tensor_uninitialized(vec![batches, height, width], B::IntElem::dtype()); + let area = client.tensor_uninitialized(vec![batches, height * width], B::IntElem::dtype()); + let left = client.tensor_uninitialized(vec![batches, height * width], B::IntElem::dtype()); + let top = client.tensor_uninitialized(vec![batches, height * width], B::IntElem::dtype()); + let right = client.tensor_uninitialized(vec![batches, height * width], B::IntElem::dtype()); + let bottom = + client.tensor_uninitialized(vec![batches, height * width], B::IntElem::dtype()); + let max_label = client.tensor_uninitialized(vec![batches], B::IntElem::dtype()); + + let desc = CustomOpDescription::new( + "connected_components", + &[img.into_description()], + &[ + out.to_description_out(), + area.to_description_out(), + left.to_description_out(), + top.to_description_out(), + right.to_description_out(), + bottom.to_description_out(), + max_label.to_description_out(), + ], + ); + client.register( + vec![stream], + OperationDescription::Custom(desc.clone()), + ConnCompStats::::new(desc, conn, opts), + ); + + let stats = ConnectedStatsPrimitive { + area, + left, + top, + right, + bottom, + max_label, + }; + (out, stats) + } +} diff --git a/crates/burn-vision/src/backends/mod.rs b/crates/burn-vision/src/backends/mod.rs new file mode 100644 index 0000000000..6886bb4907 --- /dev/null +++ b/crates/burn-vision/src/backends/mod.rs @@ -0,0 +1,3 @@ +pub(crate) mod cpu; +#[cfg(feature = "jit-backend")] +mod jit; diff --git a/crates/burn-vision/src/lib.rs b/crates/burn-vision/src/lib.rs new file mode 100644 index 0000000000..a384cde51f --- /dev/null +++ b/crates/burn-vision/src/lib.rs @@ -0,0 +1,25 @@ +//! Vision ops for burn, with GPU acceleration where possible. +//! +//! # Operations +//! Operation names are based on `opencv` wherever applicable. +//! +//! Currently implemented are: +//! - `connected_components` +//! - `connected_components_with_stats` +//! + +#![warn(missing_docs)] + +extern crate alloc; + +/// Backend implementations for JIT and CPU +pub mod backends; +mod ops; +mod tensor; + +#[cfg(feature = "export-tests")] +#[allow(missing_docs)] +mod tests; + +pub use ops::*; +pub use tensor::*; diff --git a/crates/burn-vision/src/ops/base.rs b/crates/burn-vision/src/ops/base.rs new file mode 100644 index 0000000000..f41f777405 --- /dev/null +++ b/crates/burn-vision/src/ops/base.rs @@ -0,0 +1,127 @@ +use crate::backends::cpu; +use burn_tensor::{ + backend::Backend, + ops::{BoolTensor, IntTensor}, + Int, Tensor, +}; + +/// Connected components connectivity +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] +pub enum Connectivity { + /// Four-connected (only connected in cardinal directions) + Four, + /// Eight-connected (connected if any of the surrounding 8 pixels are in the foreground) + Eight, +} + +/// Which stats should be enabled for `connected_components_with_stats`. +/// Currently only used by the GPU implementation to save on atomic operations for unneeded stats. +/// +/// Disabled stats are aliased to the labels tensor +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] +pub struct ConnectedStatsOptions { + /// Whether to enable bounding boxes + pub bounds_enabled: bool, + /// Whether to enable the max label + pub max_label_enabled: bool, + /// Whether labels must be contiguous starting at 1 + pub compact_labels: bool, +} + +/// Stats collected by the connected components analysis +/// +/// Disabled analyses may be aliased to labels +#[derive(Clone, Debug)] +pub struct ConnectedStats { + /// Total area of each component + pub area: Tensor, + /// Topmost y coordinate in the component + pub top: Tensor, + /// Leftmost x coordinate in the component + pub left: Tensor, + /// Rightmost x coordinate in the component + pub right: Tensor, + /// Bottommost y coordinate in the component + pub bottom: Tensor, + /// Scalar tensor of the max label + pub max_label: Tensor, +} + +/// Primitive version of [`ConnectedStats`], to be returned by the backend +pub struct ConnectedStatsPrimitive { + /// Total area of each component + pub area: IntTensor, + /// Leftmost x coordinate in the component + pub left: IntTensor, + /// Topmost y coordinate in the component + pub top: IntTensor, + /// Rightmost x coordinate in the component + pub right: IntTensor, + /// Bottommost y coordinate in the component + pub bottom: IntTensor, + /// Scalar tensor of the max label + pub max_label: IntTensor, +} + +impl From> for ConnectedStats { + fn from(value: ConnectedStatsPrimitive) -> Self { + ConnectedStats { + area: Tensor::from_primitive(value.area), + top: Tensor::from_primitive(value.top), + left: Tensor::from_primitive(value.left), + right: Tensor::from_primitive(value.right), + bottom: Tensor::from_primitive(value.bottom), + max_label: Tensor::from_primitive(value.max_label), + } + } +} + +impl Default for ConnectedStatsOptions { + fn default() -> Self { + Self::all() + } +} + +impl ConnectedStatsOptions { + /// Don't collect any stats + pub fn none() -> Self { + Self { + bounds_enabled: false, + max_label_enabled: false, + compact_labels: false, + } + } + + /// Collect all stats + pub fn all() -> Self { + Self { + bounds_enabled: true, + max_label_enabled: true, + compact_labels: true, + } + } +} + +/// Vision operations, implemented by each backend +pub trait VisionOps { + /// Computes the connected components labeled image of boolean image with 4 or 8 way + /// connectivity - returns a tensor of the component label of each pixel. + /// + /// `img`- The boolean image tensor in the format [batches, height, width] + fn connected_components(img: BoolTensor, connectivity: Connectivity) -> IntTensor { + cpu::connected_components::(img, connectivity) + } + + /// Computes the connected components labeled image of boolean image with 4 or 8 way + /// connectivity and collects statistics on each component - returns a tensor of the component + /// label of each pixel, along with stats collected for each component. + /// + /// `img`- The boolean image tensor in the format [batches, height, width] + fn connected_components_with_stats( + img: BoolTensor, + connectivity: Connectivity, + opts: ConnectedStatsOptions, + ) -> (IntTensor, ConnectedStatsPrimitive) { + cpu::connected_components_with_stats(img, connectivity, opts) + } +} diff --git a/crates/burn-vision/src/ops/mod.rs b/crates/burn-vision/src/ops/mod.rs new file mode 100644 index 0000000000..cbcb6ac7e7 --- /dev/null +++ b/crates/burn-vision/src/ops/mod.rs @@ -0,0 +1,3 @@ +mod base; + +pub use base::*; diff --git a/crates/burn-vision/src/tensor.rs b/crates/burn-vision/src/tensor.rs new file mode 100644 index 0000000000..5b381170a9 --- /dev/null +++ b/crates/burn-vision/src/tensor.rs @@ -0,0 +1,39 @@ +use burn_tensor::{backend::Backend, Bool, Int, Tensor}; + +use crate::{ConnectedStats, ConnectedStatsOptions, Connectivity, VisionOps}; + +/// Connected components tensor extensions +pub trait ConnectedComponents { + /// Computes the connected components labeled image of boolean image with 4 or 8 way + /// connectivity - returns a tensor of the component label of each pixel. + /// + /// `img`- The boolean image tensor in the format [batches, height, width] + fn connected_components(self, connectivity: Connectivity) -> Tensor; + + /// Computes the connected components labeled image of boolean image with 4 or 8 way + /// connectivity and collects statistics on each component - returns a tensor of the component + /// label of each pixel, along with stats collected for each component. + /// + /// `img`- The boolean image tensor in the format [batches, height, width] + fn connected_components_with_stats( + self, + connectivity: Connectivity, + options: ConnectedStatsOptions, + ) -> (Tensor, ConnectedStats); +} + +impl> ConnectedComponents for Tensor { + fn connected_components(self, connectivity: Connectivity) -> Tensor { + Tensor::from_primitive(B::connected_components(self.into_primitive(), connectivity)) + } + + fn connected_components_with_stats( + self, + connectivity: Connectivity, + options: ConnectedStatsOptions, + ) -> (Tensor, ConnectedStats) { + let (labels, stats) = + B::connected_components_with_stats(self.into_primitive(), connectivity, options); + (Tensor::from_primitive(labels), stats.into()) + } +} diff --git a/crates/burn-vision/src/tests/connected_components.rs b/crates/burn-vision/src/tests/connected_components.rs new file mode 100644 index 0000000000..368054149c --- /dev/null +++ b/crates/burn-vision/src/tests/connected_components.rs @@ -0,0 +1,185 @@ +#[burn_tensor_testgen::testgen(connected_components)] +mod tests { + use std::collections::HashMap; + + use super::*; + use burn_tensor::TensorData; + use burn_vision::{as_type, ConnectedComponents, ConnectedStatsOptions, Connectivity}; + + fn space_invader() -> [[IntType; 14]; 9] { + as_type!(IntType: [ + [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0], + [0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0], + [0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0], + [0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0], + [1, 0, 0, 1, 1, 0, 0, 0, 0, 1, 1, 0, 0, 1], + [1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1], + [1, 1, 0, 1, 1, 1, 0, 0, 1, 1, 1, 0, 1, 1], + [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + ]) + } + + #[test] + fn should_support_8_connectivity() { + let tensor = TestTensorBool::<2>::from(space_invader()).unsqueeze::<3>(); + + let output = tensor.connected_components(Connectivity::Eight); + let expected = space_invader(); // All pixels are in the same group for 8-connected + let expected = TestTensorInt::<2>::from(expected).unsqueeze::<3>(); + + normalize_labels(output.into_data()).assert_eq(&expected.into_data(), false); + } + + #[test] + fn should_support_8_connectivity_with_stats() { + let tensor = TestTensorBool::<2>::from(space_invader()).unsqueeze::<3>(); + + let (output, stats) = tensor + .connected_components_with_stats(Connectivity::Eight, ConnectedStatsOptions::all()); + let expected = space_invader(); // All pixels are in the same group for 8-connected + let expected = TestTensorInt::<2>::from(expected).unsqueeze::<3>(); + + let (area, left, top, right, bottom) = ( + stats.area.slice([0..1, 1..2]).into_data(), + stats.left.slice([0..1, 1..2]).into_data(), + stats.top.slice([0..1, 1..2]).into_data(), + stats.right.slice([0..1, 1..2]).into_data(), + stats.bottom.slice([0..1, 1..2]).into_data(), + ); + + output.into_data().assert_eq(&expected.into_data(), false); + + area.assert_eq(&TensorData::from([[58]]), false); + left.assert_eq(&TensorData::from([[0]]), false); + top.assert_eq(&TensorData::from([[0]]), false); + right.assert_eq(&TensorData::from([[13]]), false); + bottom.assert_eq(&TensorData::from([[8]]), false); + stats + .max_label + .into_data() + .assert_eq(&TensorData::from([1]), false); + } + + #[test] + fn should_support_4_connectivity() { + let tensor = TestTensorBool::<2>::from(space_invader()).unsqueeze::<3>(); + + let output = tensor.connected_components(Connectivity::Four); + let expected = as_type!(IntType: [ + [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0], + [0, 0, 0, 0, 3, 0, 0, 0, 0, 3, 0, 0, 0, 0], + [0, 0, 0, 3, 3, 3, 3, 3, 3, 3, 3, 0, 0, 0], + [0, 0, 3, 3, 0, 0, 3, 3, 0, 0, 3, 3, 0, 0], + [0, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 0], + [4, 0, 0, 3, 3, 0, 0, 0, 0, 3, 3, 0, 0, 5], + [4, 4, 0, 0, 3, 3, 3, 3, 3, 3, 0, 0, 5, 5], + [4, 4, 0, 3, 3, 3, 0, 0, 3, 3, 3, 0, 5, 5], + [0, 0, 0, 3, 0, 0, 0, 0, 0, 0, 3, 0, 0, 0], + ]); + let expected = TestTensorInt::<2>::from(expected).unsqueeze::<3>(); + + normalize_labels(output.into_data()).assert_eq(&expected.into_data(), false); + } + + #[test] + fn should_support_4_connectivity_with_stats() { + let tensor = TestTensorBool::<2>::from(space_invader()).unsqueeze::<3>(); + + let (output, stats) = tensor + .connected_components_with_stats(Connectivity::Four, ConnectedStatsOptions::all()); + let expected = as_type!(IntType: [ + [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0], + [0, 0, 0, 0, 3, 0, 0, 0, 0, 3, 0, 0, 0, 0], + [0, 0, 0, 3, 3, 3, 3, 3, 3, 3, 3, 0, 0, 0], + [0, 0, 3, 3, 0, 0, 3, 3, 0, 0, 3, 3, 0, 0], + [0, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 0], + [4, 0, 0, 3, 3, 0, 0, 0, 0, 3, 3, 0, 0, 5], + [4, 4, 0, 0, 3, 3, 3, 3, 3, 3, 0, 0, 5, 5], + [4, 4, 0, 3, 3, 3, 0, 0, 3, 3, 3, 0, 5, 5], + [0, 0, 0, 3, 0, 0, 0, 0, 0, 0, 3, 0, 0, 0], + ]); + let expected = TestTensorInt::<2>::from(expected).unsqueeze::<3>(); + + // Slice off background and limit to compacted labels + let (area, left, top, right, bottom) = ( + stats.area.slice([0..1, 1..6]).into_data(), + stats.left.slice([0..1, 1..6]).into_data(), + stats.top.slice([0..1, 1..6]).into_data(), + stats.right.slice([0..1, 1..6]).into_data(), + stats.bottom.slice([0..1, 1..6]).into_data(), + ); + + output.into_data().assert_eq(&expected.into_data(), false); + + area.assert_eq(&TensorData::from([[1, 1, 46, 5, 5]]), false); + left.assert_eq(&TensorData::from([[3, 10, 1, 0, 12]]), false); + top.assert_eq(&TensorData::from([[0, 0, 1, 5, 5]]), false); + right.assert_eq(&TensorData::from([[3, 10, 12, 1, 13]]), false); + bottom.assert_eq(&TensorData::from([[0, 0, 8, 7, 7]]), false); + stats + .max_label + .into_data() + .assert_eq(&TensorData::from([5]), false); + } + + /// Normalize labels to sequential since actual labels aren't required to be contiguous and + /// different algorithms can return different numbers even if correct + fn normalize_labels(mut labels: TensorData) -> TensorData { + let mut next_label = 0; + let mut mappings = HashMap::::default(); + let data = labels.as_mut_slice::().unwrap(); + for label in data { + if *label != 0 { + let relabel = mappings.entry(*label).or_insert_with(|| { + next_label += 1; + next_label + }); + *label = *relabel; + } + } + labels + } + + fn normalize_stats( + area: TensorData, + left: TensorData, + top: TensorData, + right: TensorData, + bottom: TensorData, + ) -> (TensorData, TensorData, TensorData, TensorData, TensorData) { + let batches = area.shape[0]; + + let area = area.as_slice::().unwrap(); + let left = left.as_slice::().unwrap(); + let top = top.as_slice::().unwrap(); + let right = right.as_slice::().unwrap(); + let bottom = bottom.as_slice::().unwrap(); + + let mut area_new = vec![]; + let mut left_new = vec![]; + let mut top_new = vec![]; + let mut right_new = vec![]; + let mut bottom_new = vec![]; + + for (label, area) in area.iter().enumerate() { + if *area != 0 { + area_new.push(*area); + left_new.push(left[label]); + top_new.push(top[label]); + right_new.push(right[label]); + bottom_new.push(bottom[label]); + } + } + + let shape = [batches, area_new.len() / batches]; + + ( + TensorData::new(area_new, shape.clone()), + TensorData::new(left_new, shape.clone()), + TensorData::new(top_new, shape), + TensorData::new(right_new, shape.clone()), + TensorData::new(bottom_new, shape.clone()), + ) + } +} diff --git a/crates/burn-vision/src/tests/mod.rs b/crates/burn-vision/src/tests/mod.rs new file mode 100644 index 0000000000..11851577ed --- /dev/null +++ b/crates/burn-vision/src/tests/mod.rs @@ -0,0 +1,37 @@ +mod connected_components; + +#[macro_export] +macro_rules! testgen_all { + () => { + use burn_tensor::{Bool, Float, Int}; + + pub type TestTensorInt = burn_tensor::Tensor; + pub type TestTensorBool = burn_tensor::Tensor; + + pub mod vision { + pub use super::*; + + pub type IntType = ::IntElem; + + burn_vision::testgen_connected_components!(); + } + }; +} + +#[allow(missing_docs)] +#[macro_export] +macro_rules! as_type { + ($ty:ident: [$($elem:tt),*]) => { + [$($crate::as_type![$ty: $elem]),*] + }; + ($ty:ident: [$($elem:tt,)*]) => { + [$($crate::as_type![$ty: $elem]),*] + }; + ($ty:ident: $elem:expr) => { + { + use cubecl::prelude::*; + + $ty::new($elem) + } + }; +} diff --git a/crates/burn-vision/tests/main.rs b/crates/burn-vision/tests/main.rs new file mode 100644 index 0000000000..6bd8dbfb96 --- /dev/null +++ b/crates/burn-vision/tests/main.rs @@ -0,0 +1,27 @@ +#[cfg(all(test, feature = "test-cpu"))] +mod tests_cpu { + pub type TestBackend = burn_ndarray::NdArray; + + burn_vision::testgen_all!(); +} + +#[cfg(all(test, feature = "test-wgpu"))] +mod tests_wgpu { + pub type TestBackend = burn_wgpu::Wgpu; + + burn_vision::testgen_all!(); +} + +#[cfg(all(test, feature = "test-vulkan"))] +mod tests_wgpu { + pub type TestBackend = burn_wgpu::Vulkan; + + burn_vision::testgen_all!(); +} + +#[cfg(all(test, feature = "test-cuda"))] +mod tests_cuda { + pub type TestBackend = burn_cuda::Cuda; + + burn_vision::testgen_all!(); +} diff --git a/crates/burn-wgpu/Cargo.toml b/crates/burn-wgpu/Cargo.toml index e0c247172d..3f4e2dcc0d 100644 --- a/crates/burn-wgpu/Cargo.toml +++ b/crates/burn-wgpu/Cargo.toml @@ -21,12 +21,12 @@ std = ["burn-jit/std", "cubecl/std"] template = ["burn-jit/template", "cubecl/template"] # Backends -webgpu = ["cubecl-wgsl"] vulkan = ["cubecl-spirv"] +webgpu = ["cubecl-wgsl"] # Compilers -cubecl-wgsl = [] cubecl-spirv = ["cubecl/wgpu-spirv"] +cubecl-wgsl = [] [dependencies] cubecl = { workspace = true, features = ["wgpu"] } diff --git a/xtask/src/commands/test.rs b/xtask/src/commands/test.rs index 5b94b2909e..dbd7eab842 100644 --- a/xtask/src/commands/test.rs +++ b/xtask/src/commands/test.rs @@ -67,6 +67,15 @@ pub(crate) fn handle_command( "std with features: test-tch,record-item-custom-serde", )?; + // burn-vision + helpers::custom_crates_tests( + vec!["burn-vision"], + vec!["--features", "test-cpu"], + None, + None, + "std cpu", + )?; + if std::env::var("DISABLE_WGPU").is_err() { helpers::custom_crates_tests( vec!["burn-core"], @@ -75,6 +84,13 @@ pub(crate) fn handle_command( None, "std wgpu", )?; + helpers::custom_crates_tests( + vec!["burn-vision"], + vec!["--features", "test-wgpu"], + None, + None, + "std wgpu", + )?; // Vulkan isn't available on MacOS #[cfg(not(target_os = "macos"))] if std::env::var("DISABLE_WGPU_SPIRV").is_err() { @@ -85,6 +101,13 @@ pub(crate) fn handle_command( None, "std vulkan", )?; + helpers::custom_crates_tests( + vec!["burn-vision"], + vec!["--features", "vulkan"], + None, + None, + "std vulkan", + )?; } }