
Commit

WIP but functioning adapters implementation. Postprocessing part of adapters now. Changed "id" to "identifier". Turned Metrics into an nn.Module container to remove the need for PatchedModuleDict.
ibro45 committed Jan 6, 2025
1 parent 8982de4 commit 9ba8987
Showing 18 changed files with 512 additions and 312 deletions.
272 changes: 272 additions & 0 deletions lighter/adapters.py
@@ -0,0 +1,272 @@
from typing import Any, Callable

from abc import ABC

from lighter.utils.misc import ensure_list


class TransformAdapter(ABC):
"""
An abstract base class for applying transform functions to data.
"""

def _transform(self, data: Any, transforms: Callable | list[Callable]) -> Any:
"""
Applies a list of transform functions to the data.
Args:
data: The data to be transformed.
transforms: A single transform function or a list of functions.
Returns:
The transformed data.
Raises:
ValueError: If any transform is not callable.
"""
for transform in ensure_list(transforms):
if callable(transform):
data = transform(data)
else:
raise ValueError(f"Invalid transform: {transform} is not callable.")
return data
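# Illustrative sketch (not part of the committed file): transforms are applied left to
# right, so the call below evaluates to (2.0 * 10) + 1 == 21.0. TransformAdapter declares
# no abstract methods, so it can be instantiated directly for a quick sanity check.
_demo_transform_adapter = TransformAdapter()
assert _demo_transform_adapter._transform(2.0, [lambda x: x * 10, lambda x: x + 1]) == 21.0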


class BatchAdapter:
def __init__(
self,
input_accessor: int | str | Callable | None = None,
target_accessor: int | str | Callable | None = None,
identifier_accessor: int | str | Callable | None = None,
):
"""
Initializes BatchAdapter with accessors for the input, target, and identifier.
Args:
input_accessor: Accessor for the input data. Can be an index (for lists/tuples),
a key (for dictionaries), or a callable.
target_accessor: Accessor for the target data. Can be an index (for lists/tuples),
a key (for dictionaries), or a callable.
identifier_accessor: Accessor for the identifier data. Can be an index (for lists/tuples),
a key (for dictionaries), or a callable.
"""
self.input_accessor = input_accessor
self.target_accessor = target_accessor
self.identifier_accessor = identifier_accessor

def identifier(self, data: Any) -> Any:
# TODO: Decide on the default value; old lighter returned None if the identifier doesn't exist.
return self._access_value(data, self.identifier_accessor)

def input(self, data: Any) -> Any:
return self._access_value(data, self.input_accessor)

def target(self, data: Any) -> Any:
return self._access_value(data, self.target_accessor)

def _access_value(self, data: Any, accessor: int | str | Callable | None) -> Any:
"""
Accesses a value from the data using the provided accessor.
Args:
data: The data to access the value from.
accessor: The accessor to use. Can be an index (for lists/tuples),
a key (for dictionaries), or a callable.
Returns:
The accessed value.
Raises:
ValueError: If the accessor type or data structure is invalid.
"""
if accessor is None:
return data
elif isinstance(accessor, int) and isinstance(data, (tuple, list)):
return data[accessor]
elif isinstance(accessor, str) and isinstance(data, dict):
return data.get(accessor)
elif callable(accessor):
return accessor(data)
else:
raise ValueError(f"Invalid accessor {accessor} of type {type(accessor)} for data type {type(data)}.")
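# Illustrative sketch (not part of the committed file): an accessor can be a dict key,
# a sequence index, or a callable; a None accessor returns the data unchanged.
_demo_dict_batch = {"image": "input_value", "label": "target_value", "identifier": "case_001"}
_demo_batch_adapter = BatchAdapter(input_accessor="image", target_accessor="label", identifier_accessor="identifier")
assert _demo_batch_adapter.input(_demo_dict_batch) == "input_value"
assert _demo_batch_adapter.identifier(_demo_dict_batch) == "case_001"
_demo_tuple_adapter = BatchAdapter(input_accessor=0, target_accessor=lambda batch: batch[1])
assert _demo_tuple_adapter.target(("input_value", "target_value")) == "target_value"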


class FunctionAdapter(TransformAdapter):
"""
A generic adapter that routes input, target, and prediction data into the arguments of a function (e.g. a criterion or metrics callable), applying optional transforms first.
"""

def __init__(
self,
input_argument: int | str | None = None,
target_argument: int | str | None = None,
pred_argument: int | str | None = None,
input_transforms: list[Callable] | None = None,
target_transforms: list[Callable] | None = None,
pred_transforms: list[Callable] | None = None,
):
"""
Initializes FunctionAdapter with arguments and transforms for input, target, and prediction.
Args:
input_argument: Position (int) or keyword name (str) under which the input data is passed to the function.
target_argument: Position (int) or keyword name (str) under which the target data is passed to the function.
pred_argument: Position (int) or keyword name (str) under which the prediction data is passed to the function.
input_transforms: A list of transforms to apply to the input data.
target_transforms: A list of transforms to apply to the target data.
pred_transforms: A list of transforms to apply to the prediction data.
Raises:
ValueError: If transforms are provided but the corresponding argument is None.
"""
if input_argument is None and input_transforms is not None:
raise ValueError("Input transforms provided but input_argument is None")
if target_argument is None and target_transforms is not None:
raise ValueError("Target transforms provided but target_argument is None")
if pred_argument is None and pred_transforms is not None:
raise ValueError("Pred transforms provided but pred_argument is None")

self.input_argument = input_argument
self.target_argument = target_argument
self.pred_argument = pred_argument

self.input_transforms = input_transforms
self.target_transforms = target_transforms
self.pred_transforms = pred_transforms

def __call__(self, func: Callable, input: Any, target: Any, pred: Any) -> Any:
"""
Applies the given function to the input, target, and prediction data.
Args:
func: The function to apply.
input: The input data.
target: The target data.
pred: The prediction data.
Returns:
The result of the function call.
"""
args = []
kwargs = {}
if self.input_argument is not None:
input = self._transform(input, self.input_transforms)
if isinstance(self.input_argument, int):
args.insert(self.input_argument, input)
else:
kwargs[self.input_argument] = input

if self.target_argument is not None:
target = self._transform(target, self.target_transforms)
if isinstance(self.target_argument, int):
args.insert(self.target_argument, target)
else:
kwargs[self.target_argument] = target

if self.pred_argument is not None:
pred = self._transform(pred, self.pred_transforms)
if isinstance(self.pred_argument, int):
args.insert(self.pred_argument, pred)
else:
kwargs[self.pred_argument] = pred

return func(*args, **kwargs)
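# Illustrative sketch (not part of the committed file): an int argument maps a value to a
# positional slot, a str argument maps it to a keyword, and transforms run before the call.
# torch is already a lighter dependency; the tensors below are made-up examples.
import torch
_demo_criterion_adapter = FunctionAdapter(pred_argument=0, target_argument=1, pred_transforms=[torch.sigmoid])
_demo_pred = torch.randn(4)
_demo_target = torch.randint(0, 2, (4,)).float()
_demo_loss = _demo_criterion_adapter(torch.nn.functional.binary_cross_entropy, input=None, target=_demo_target, pred=_demo_pred)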


class CriterionAdapter(FunctionAdapter):
"""
An adapter specifically for loss (criterion) computations.
"""

def __call__(self, criterion: Callable, input: Any, target: Any, pred: Any) -> Any:
"""
Applies the criterion to the input, target, and prediction data.
Args:
criterion: The criterion (loss function) to apply.
input: The input data.
target: The target data.
pred: The prediction data.
Returns:
The result of the criterion call.
"""
return super().__call__(criterion, input, target, pred)


class MetricsAdapter(FunctionAdapter):
"""
An adapter specifically for metrics calculations.
"""

def __call__(self, metrics: Callable, input: Any, target: Any, pred: Any) -> Any:
"""
Calculates metrics using the provided function and data.
Args:
metrics: The metrics function to apply.
input: The input data.
target: The target data.
pred: The prediction data.
Returns:
The result of the metrics calculation.
"""
return super().__call__(metrics, input, target, pred)


class LoggingAdapter(TransformAdapter):
"""
An adapter for applying transformations to data before logging.
"""

def __init__(
self,
input_transforms: list[Callable] | None = None,
target_transforms: list[Callable] | None = None,
pred_transforms: list[Callable] | None = None,
):
"""
Initializes LoggingAdapter with transforms for input, target, and prediction.
Args:
input_transforms: A list of transforms to apply to the input data.
target_transforms: A list of transforms to apply to the target data.
pred_transforms: A list of transforms to apply to the prediction data.
"""

self.input_transforms = input_transforms
self.target_transforms = target_transforms
self.pred_transforms = pred_transforms

def input(self, data: Any):
"""
Transforms the input data for logging.
Args:
data: The input data.
Returns:
The transformed input data.
"""
return self._transform(data, self.input_transforms)

def target(self, data: Any):
"""
Transforms the target data for logging.
Args:
data: The target data.
Returns:
The transformed target data.
"""
return self._transform(data, self.target_transforms)

def pred(self, data: Any):
"""
Transforms the prediction data for logging.
Args:
data: The prediction data.
Returns:
The transformed prediction data.
"""
return self._transform(data, self.pred_transforms)
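# Illustrative sketch (not part of the committed file): LoggingAdapter only transforms values
# before they are logged, e.g. detaching tensors and moving them to the CPU. How the adapters
# are wired into System is assumed to happen elsewhere in this PR and is not shown here.
import torch
_demo_logging_adapter = LoggingAdapter(pred_transforms=[lambda t: t.detach().cpu()])
_demo_logged = _demo_logging_adapter.pred(torch.randn(2, requires_grad=True) * 2)
assert _demo_logged.requires_grad is False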
16 changes: 9 additions & 7 deletions lighter/callbacks/writer/base.py
@@ -15,7 +15,7 @@
from torch import Tensor

from lighter import System
from lighter.utils.types import Data, Stage
from lighter.utils.types.enums import Data, Stage


class BaseWriter(ABC, Callback):
@@ -54,7 +54,7 @@ def writers(self) -> dict[str, Callable]:
"""

@abstractmethod
def write(self, tensor: Tensor, id: int) -> None:
def write(self, tensor: Tensor, identifier: int) -> None:
"""
Method to define how a tensor should be saved. The input tensor will be a single tensor without
the batch dimension.
@@ -64,7 +64,7 @@ def write(self, tensor: Tensor, id: int) -> None:
Args:
tensor (Tensor): Tensor, without the batch dimension, to be saved.
id (int): Identifier for the tensor, can be used for naming files or adding table records.
identifier (int): Identifier for the tensor, can be used for naming files or adding table records.
"""

def setup(self, trainer: Trainer, pl_module: System, stage: str) -> None:
@@ -117,14 +117,16 @@ def on_predict_batch_end(
dataloader_idx (int): The index of the dataloader.
"""
# If the IDs are not provided, generate global unique IDs based on the prediction count. DDP supported.
if outputs[Data.ID] is None:
if outputs[Data.IDENTIFIER] is None:
batch_size = len(outputs[Data.PRED])
world_size = trainer.world_size
outputs[Data.ID] = list(range(self._pred_counter, self._pred_counter + batch_size * world_size, world_size))
outputs[Data.IDENTIFIER] = list(
range(self._pred_counter, self._pred_counter + batch_size * world_size, world_size)
)
self._pred_counter += batch_size * world_size

for id, pred in zip(outputs[Data.ID], outputs[Data.PRED]):
self.write(tensor=pred, id=id)
for identifier, pred in zip(outputs[Data.IDENTIFIER], outputs[Data.PRED]):
self.write(tensor=pred, identifier=identifier)

# Clear the predictions to save CPU memory. https://github.com/Lightning-AI/pytorch-lightning/issues/19398
trainer.predict_loop._predictions = [[] for _ in range(trainer.predict_loop.num_dataloaders)]
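# Illustrative note (not part of the diff): evaluating the identifier expression above with
# example values shows the spacing. With _pred_counter = 0, batch_size = 4 and world_size = 2,
# each process gets IDs spaced by world_size, and the counter then advances by 8, so the next
# batch on the same process yields [8, 10, 12, 14]. The per-rank offset that keeps processes
# from colliding is assumed to come from how _pred_counter is initialized, outside this hunk.
assert list(range(0, 0 + 4 * 2, 2)) == [0, 2, 4, 6]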
6 changes: 3 additions & 3 deletions lighter/callbacks/writer/file.py
@@ -48,19 +48,19 @@ def writers(self) -> dict[str, Callable]:
"itk_nifti": partial(write_itk_image, suffix=".nii.gz"),
}

def write(self, tensor: Tensor, id: int | str) -> None:
def write(self, tensor: Tensor, identifier: int | str) -> None:
"""
Writes the tensor to a file using the specified writer.
Args:
tensor: The tensor to write.
id: Identifier for naming the file.
identifier: Identifier for naming the file.
"""
if not self.path.is_dir():
raise RuntimeError(f"FileWriter expects a directory path, got {self.path}")

# Determine the path for the file based on prediction count. The suffix must be added by the writer function.
path = self.path / str(id)
path = self.path / str(identifier)
# Write the tensor to the file.
self.writer(path, tensor)

10 changes: 5 additions & 5 deletions lighter/callbacks/writer/table.py
@@ -34,15 +34,15 @@ def writers(self) -> dict[str, Callable]:
"tensor": lambda tensor: tensor.item() if tensor.numel() == 1 else tensor.tolist(),
}

def write(self, tensor: Any, id: int | str) -> None:
def write(self, tensor: Any, identifier: int | str) -> None:
"""
Writes the tensor as a table record using the specified writer.
Args:
tensor: The tensor to record. Should not have a batch dimension.
id: Identifier for the record.
identifier: Identifier for the record.
"""
self.csv_records.append({"id": id, "pred": self.writer(tensor)})
self.csv_records.append({"identifier": identifier, "pred": self.writer(tensor)})

def on_predict_epoch_end(self, trainer: Trainer, pl_module: System) -> None:
"""
@@ -63,10 +63,10 @@ def on_predict_epoch_end(self, trainer: Trainer, pl_module: System) -> None:
if trainer.is_global_zero:
df = pd.DataFrame(self.csv_records)
try:
df = df.sort_values("id")
df = df.sort_values("identifier")
except TypeError:
pass
df = df.set_index("id")
df = df.set_index("identifier")
df.to_csv(self.path)

# Clear the records after saving
(The remaining 14 changed files were not rendered on this page.)
