From 3290b2411e269b1f58d92871ea36c8c86b0920a6 Mon Sep 17 00:00:00 2001
From: leavauchier <120112647+leavauchier@users.noreply.github.com>
Date: Mon, 23 Jan 2023 15:23:54 +0100
Subject: [PATCH] Optimize memory usage during interpolation+saving (#50)

* Delete variables with high memory usage

* Delay las loading when saving results

Co-authored-by: GLiegard
---
 myria3d/models/interpolation.py | 24 +++++++++++++++++-------
 myria3d/pctl/dataset/utils.py   | 19 +++++++++++++++++++
 2 files changed, 36 insertions(+), 7 deletions(-)

diff --git a/myria3d/models/interpolation.py b/myria3d/models/interpolation.py
index ae9d6832..b5c85fca 100644
--- a/myria3d/models/interpolation.py
+++ b/myria3d/models/interpolation.py
@@ -8,7 +8,7 @@
 from torch.distributions import Categorical
 from torch_scatter import scatter_sum
 
-from myria3d.pctl.dataset.utils import get_pdal_reader
+from myria3d.pctl.dataset.utils import get_pdal_info_metadata, get_pdal_reader
 
 log = logging.getLogger(__name__)
 
@@ -84,7 +84,7 @@ def store_predictions(self, logits, idx_in_original_cloud) -> None:
         self.idx_in_full_cloud_list += idx_in_original_cloud
 
     @torch.no_grad()
-    def reduce_predicted_logits(self, las) -> torch.Tensor:
+    def reduce_predicted_logits(self, nb_points) -> torch.Tensor:
         """Interpolate logits to points without predictions using an inverse-distance weighting scheme.
 
         Returns:
@@ -100,7 +100,7 @@
 
         # We scatter_sum logits based on idx, in case there are multiple predictions for a point.
         # scatter_sum reorders logits based on index, so they match the las order.
-        reduced_logits = torch.zeros((len(las), logits.size(1)))
+        reduced_logits = torch.zeros((nb_points, logits.size(1)))
         scatter_sum(logits, torch.from_numpy(idx_in_full_cloud), out=reduced_logits, dim=0)
         # reduced_logits contains logits ordered by their idx in the original cloud!
         # Warning: some points may not have any predictions if they were in small areas.
@@ -120,22 +120,32 @@ def reduce_predictions_and_save(self, raw_path: str, output_dir: str) -> str:
 
         """
         basename = os.path.basename(raw_path)
-        las = self.load_full_las_for_update(src_las=raw_path)
-        logits = self.reduce_predicted_logits(las)
+        # Read the number of points from the las metadata only, to minimize memory usage
+        nb_points = get_pdal_info_metadata(raw_path)["count"]
+        logits = self.reduce_predicted_logits(nb_points)
         probas = torch.nn.Softmax(dim=1)(logits)
+
+        if self.predicted_classification_channel:
+            preds = torch.argmax(logits, dim=1)
+            preds = np.vectorize(self.reverse_mapper.get)(preds)
+
+        del logits
+
+        # Load the las only once everything to write into it has been computed
+        las = self.load_full_las_for_update(src_las=raw_path)
+
         for idx, class_name in enumerate(self.classification_dict.values()):
             if class_name in self.probas_to_save:
                 las[class_name] = probas[:, idx]
 
         if self.predicted_classification_channel:
-            preds = torch.argmax(logits, dim=1)
-            preds = np.vectorize(self.reverse_mapper.get)(preds)
             las[self.predicted_classification_channel] = preds
             log.info(
                 f"Saving predicted classes to channel {self.predicted_classification_channel}. "
                 "Channel name can be changed by setting `predict.interpolator.predicted_classification_channel`."
             )
+            del preds
 
         if self.entropy_channel:
             las[self.entropy_channel] = Categorical(probs=probas).entropy()
 
diff --git a/myria3d/pctl/dataset/utils.py b/myria3d/pctl/dataset/utils.py
index 586a79e5..1a23f397 100644
--- a/myria3d/pctl/dataset/utils.py
+++ b/myria3d/pctl/dataset/utils.py
@@ -1,5 +1,7 @@
 import glob
+import json
 import math
+import subprocess as sp
 from numbers import Number
 from typing import Dict, List, Literal, Union
 
@@ -82,6 +84,23 @@ def get_pdal_reader(las_path: str) -> pdal.Reader.las:
     )
 
 
+def get_pdal_info_metadata(las_path: str) -> Dict:
+    """Read las metadata using `pdal info`.
+
+    Args:
+        las_path (str): input LAS path to read.
+    Returns:
+        (dict): dictionary containing the metadata of the las file.
+    """
+    r = sp.run(["pdal", "info", "--metadata", las_path], capture_output=True)
+    if r.returncode != 0:
+        msg = r.stderr.decode()
+        raise RuntimeError(msg)
+
+    output = r.stdout.decode()
+    json_info = json.loads(output)
+
+    return json_info["metadata"]
 
 
 # hdf5, iterable
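
Why this helps: `pdal info --metadata` prints the file's metadata to stdout as
JSON, and for LAS files the point count is stored in the header, so the number
of points can be obtained without loading any point records into memory. The
interpolator then only needs this count to allocate `reduced_logits`, and the
full cloud is loaded a single time, right before writing. A minimal standalone
sketch of the same idea (the `point_count` helper below is hypothetical; it
assumes only that the `pdal` CLI is on PATH):

    import json
    import subprocess as sp

    def point_count(las_path: str) -> int:
        # Query only the metadata block of the file; the point
        # records themselves are never read into memory.
        r = sp.run(["pdal", "info", "--metadata", las_path], capture_output=True)
        if r.returncode != 0:
            raise RuntimeError(r.stderr.decode())
        return json.loads(r.stdout.decode())["metadata"]["count"]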