diff --git a/myria3d/models/modules/randla_net.py b/myria3d/models/modules/randla_net.py index c8d12079..77d24477 100644 --- a/myria3d/models/modules/randla_net.py +++ b/myria3d/models/modules/randla_net.py @@ -409,7 +409,7 @@ def knn_compact( ) # Get back to compact shape - batch_size_y = len(pos_y[1]) + batch_size_y = pos_y.size(1) compact_shape_y = (num_graphs, batch_size_y, -1) x_idx = x_idx_long.view(compact_shape_y) y_idx = y_idx_long.view(compact_shape_y) diff --git a/tests/myria3d/models/modules/test_randla_net.py b/tests/myria3d/models/modules/test_randla_net.py index 86446dd6..ee7f6a32 100644 --- a/tests/myria3d/models/modules/test_randla_net.py +++ b/tests/myria3d/models/modules/test_randla_net.py @@ -1,10 +1,17 @@ +import pytest import torch from torch_geometric.data import Batch from myria3d.models.modules.randla_net import RandLANet -def test_fake_run_randlanet(): - """Documents expected data format and make a forward pass with RandLa-Net""" +@pytest.mark.parametrize("num_graphs", [1, 4]) +def test_fake_run_randlanet(num_graphs): + """Documents expected data format and make a forward pass with RandLa-Net + + Model pass with "batch_size=1" is a edge case that needs to pass to avoid unexpected crash due to incomplete batch at the + end of an inference. + + """ num_euclidian_dimensions = 3 num_features = 9 d_in = num_euclidian_dimensions + num_features @@ -17,7 +24,7 @@ def test_fake_run_randlanet(): "num_classes": num_classes, } batch = Batch() - batch.num_graphs = 4 + batch.num_graphs = 1 num_points = 12500 batch.pos = torch.rand((num_points * batch.num_graphs, num_euclidian_dimensions)) batch.x = torch.rand((num_points * batch.num_graphs, num_features))