Skip to content

Commit

Permalink
Optimize memory usage during interpolation+saving (#50)
Browse files Browse the repository at this point in the history
* Delete variables with high memory usage

* Delay las loading when saving results

Co-authored-by: GLiegard <[email protected]>
  • Loading branch information
leavauchier and gliegard authored Jan 23, 2023
1 parent af17f90 commit 3290b24
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 7 deletions.
24 changes: 17 additions & 7 deletions myria3d/models/interpolation.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from torch.distributions import Categorical
from torch_scatter import scatter_sum

from myria3d.pctl.dataset.utils import get_pdal_reader
from myria3d.pctl.dataset.utils import get_pdal_info_metadata, get_pdal_reader

log = logging.getLogger(__name__)

Expand Down Expand Up @@ -84,7 +84,7 @@ def store_predictions(self, logits, idx_in_original_cloud) -> None:
self.idx_in_full_cloud_list += idx_in_original_cloud

@torch.no_grad()
def reduce_predicted_logits(self, las) -> torch.Tensor:
def reduce_predicted_logits(self, nb_points) -> torch.Tensor:
"""Interpolate logits to points without predictions using an inverse-distance weightning scheme.
Returns:
Expand All @@ -100,7 +100,7 @@ def reduce_predicted_logits(self, las) -> torch.Tensor:

# We scatter_sum logits based on idx, in case there are multiple predictions for a point.
        # scatter_sum reorders logits based on index; they therefore match las order.
reduced_logits = torch.zeros((len(las), logits.size(1)))
reduced_logits = torch.zeros((nb_points, logits.size(1)))
scatter_sum(logits, torch.from_numpy(idx_in_full_cloud), out=reduced_logits, dim=0)
# reduced_logits contains logits ordered by their idx in original cloud !
# Warning : some points may not contain any predictions if they were in small areas.
Expand All @@ -120,22 +120,32 @@ def reduce_predictions_and_save(self, raw_path: str, output_dir: str) -> str:
"""
basename = os.path.basename(raw_path)
las = self.load_full_las_for_update(src_las=raw_path)
logits = self.reduce_predicted_logits(las)
# Read number of points only from las metadata in order to minimize memory usage
nb_points = get_pdal_info_metadata(raw_path)["count"]
logits = self.reduce_predicted_logits(nb_points)

probas = torch.nn.Softmax(dim=1)(logits)

if self.predicted_classification_channel:
preds = torch.argmax(logits, dim=1)
preds = np.vectorize(self.reverse_mapper.get)(preds)

del logits

# Read las after fetching all information to write into it
las = self.load_full_las_for_update(src_las=raw_path)

for idx, class_name in enumerate(self.classification_dict.values()):
if class_name in self.probas_to_save:
las[class_name] = probas[:, idx]

if self.predicted_classification_channel:
preds = torch.argmax(logits, dim=1)
preds = np.vectorize(self.reverse_mapper.get)(preds)
las[self.predicted_classification_channel] = preds
log.info(
f"Saving predicted classes to channel {self.predicted_classification_channel}."
"Channel name can be changed by setting `predict.interpolator.predicted_classification_channel`."
)
del preds

if self.entropy_channel:
las[self.entropy_channel] = Categorical(probs=probas).entropy()
Expand Down
19 changes: 19 additions & 0 deletions myria3d/pctl/dataset/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import glob
import json
import math
import subprocess as sp
from numbers import Number
from typing import Dict, List, Literal, Union

Expand Down Expand Up @@ -82,6 +84,23 @@ def get_pdal_reader(las_path: str) -> pdal.Reader.las:
)


def get_pdal_info_metadata(las_path: str) -> Dict:
    """Read las metadata using ``pdal info --metadata``.

    Runs the ``pdal`` CLI in a subprocess so that only the file header is
    parsed -- the point data itself is never loaded into memory.

    Args:
        las_path (str): input LAS path to read.

    Returns:
        (dict): dictionary containing metadata from the las file
            (the ``metadata`` entry of ``pdal info``'s JSON output).

    Raises:
        RuntimeError: if the ``pdal info`` subprocess exits with a non-zero
            status; its stderr is used as the error message.
    """
    r = sp.run(["pdal", "info", "--metadata", las_path], capture_output=True)
    # Any non-zero exit code is a failure, not only 1 (pdal may also exit
    # with 2, or a negative value if killed by a signal on POSIX).
    if r.returncode != 0:
        msg = r.stderr.decode()
        raise RuntimeError(msg)

    output = r.stdout.decode()
    json_info = json.loads(output)

    return json_info["metadata"]

# hdf5, iterable


Expand Down

0 comments on commit 3290b24

Please sign in to comment.