Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[bug] Resume dictdataloader support for Trainer #627

Merged
merged 7 commits into from
Dec 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 36 additions & 25 deletions qadence/ml_tools/callbacks/writer_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,17 +7,14 @@
from typing import Any, Callable, Union
from uuid import uuid4

import mlflow
from matplotlib.figure import Figure
from mlflow.entities import Run
from mlflow.models import infer_signature
from torch import Tensor
from torch.nn import Module
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

from qadence.ml_tools.config import TrainConfig
from qadence.ml_tools.data import OptimizeResult
from qadence.ml_tools.data import DictDataLoader, OptimizeResult
from qadence.types import ExperimentTrackingTool

logger = getLogger("ml_tools")
Expand All @@ -43,7 +40,7 @@ class BaseWriter(ABC):
log_model(model, dataloader): Logs the model and any relevant information.
"""

run: Run # [attr-defined]
run: Any # [attr-defined]

@abstractmethod
def open(self, config: TrainConfig, iteration: int | None = None) -> Any:
Expand Down Expand Up @@ -104,18 +101,18 @@ def plot(
def log_model(
self,
model: Module,
train_dataloader: DataLoader | None = None,
val_dataloader: DataLoader | None = None,
test_dataloader: DataLoader | None = None,
train_dataloader: DataLoader | DictDataLoader | None = None,
val_dataloader: DataLoader | DictDataLoader | None = None,
test_dataloader: DataLoader | DictDataLoader | None = None,
) -> None:
"""
Logs the model and associated data.

Args:
model (Module): The model to log.
train_dataloader (DataLoader | None): DataLoader for training data.
val_dataloader (DataLoader | None): DataLoader for validation data.
test_dataloader (DataLoader | None): DataLoader for testing data.
train_dataloader (DataLoader | DictDataLoader | None): DataLoader for training data.
val_dataloader (DataLoader | DictDataLoader | None): DataLoader for validation data.
test_dataloader (DataLoader | DictDataLoader | None): DataLoader for testing data.
"""
raise NotImplementedError("Writers must implement a log_model method.")

Expand Down Expand Up @@ -231,9 +228,9 @@ def plot(
def log_model(
self,
model: Module,
train_dataloader: DataLoader | None = None,
val_dataloader: DataLoader | None = None,
test_dataloader: DataLoader | None = None,
train_dataloader: DataLoader | DictDataLoader | None = None,
val_dataloader: DataLoader | DictDataLoader | None = None,
test_dataloader: DataLoader | DictDataLoader | None = None,
) -> None:
"""
Logs the model.
Expand All @@ -242,9 +239,9 @@ def log_model(

Args:
model (Module): The model to log.
train_dataloader (DataLoader | None): DataLoader for training data.
val_dataloader (DataLoader | None): DataLoader for validation data.
test_dataloader (DataLoader | None): DataLoader for testing data.
train_dataloader (DataLoader | DictDataLoader | None): DataLoader for training data.
val_dataloader (DataLoader | DictDataLoader | None): DataLoader for validation data.
test_dataloader (DataLoader | DictDataLoader | None): DataLoader for testing data.
"""
logger.warning("Model logging is not supported by tensorboard. No model will be logged.")

Expand All @@ -259,6 +256,14 @@ class MLFlowWriter(BaseWriter):
"""

def __init__(self) -> None:
try:
from mlflow.entities import Run
except ImportError:
raise ImportError(
"mlflow is not installed. Please install qadence with the mlflow feature: "
"`pip install qadence[mlflow]`."
)

self.run: Run
self.mlflow: ModuleType

Expand All @@ -274,6 +279,8 @@ def open(self, config: TrainConfig, iteration: int | None = None) -> ModuleType
Returns:
mlflow: The MLflow module instance.
"""
import mlflow
mlahariya marked this conversation as resolved.
Show resolved Hide resolved

self.mlflow = mlflow
tracking_uri = os.getenv("MLFLOW_TRACKING_URI", "")
experiment_name = os.getenv("MLFLOW_EXPERIMENT_NAME", str(uuid4()))
Expand Down Expand Up @@ -356,17 +363,21 @@ def plot(
"Please call the 'writer.open()' method before writing"
)

def get_signature_from_dataloader(self, model: Module, dataloader: DataLoader | None) -> Any:
def get_signature_from_dataloader(
self, model: Module, dataloader: DataLoader | DictDataLoader | None
) -> Any:
"""
Infers the signature of the model based on the input data from the dataloader.

Args:
model (Module): The model to use for inference.
dataloader (DataLoader | None): DataLoader for model inputs.
dataloader (DataLoader | DictDataLoader | None): DataLoader for model inputs.

Returns:
Optional[Any]: The inferred signature, if available.
"""
from mlflow.models import infer_signature

if dataloader is None:
return None

Expand All @@ -384,18 +395,18 @@ def get_signature_from_dataloader(self, model: Module, dataloader: DataLoader |
def log_model(
self,
model: Module,
train_dataloader: DataLoader | None = None,
val_dataloader: DataLoader | None = None,
test_dataloader: DataLoader | None = None,
train_dataloader: DataLoader | DictDataLoader | None = None,
val_dataloader: DataLoader | DictDataLoader | None = None,
test_dataloader: DataLoader | DictDataLoader | None = None,
) -> None:
"""
Logs the model and its signature to MLflow using the provided data loaders.

Args:
model (Module): The model to log.
train_dataloader (DataLoader | None): DataLoader for training data.
val_dataloader (DataLoader | None): DataLoader for validation data.
test_dataloader (DataLoader | None): DataLoader for testing data.
train_dataloader (DataLoader | DictDataLoader | None): DataLoader for training data.
val_dataloader (DataLoader | DictDataLoader | None): DataLoader for validation data.
test_dataloader (DataLoader | DictDataLoader | None): DataLoader for testing data.
"""
if not self.mlflow:
raise RuntimeError(
Expand Down
59 changes: 33 additions & 26 deletions qadence/ml_tools/train_utils/base_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,11 @@
import torch
from nevergrad.optimization.base import Optimizer as NGOptimizer
from torch import nn, optim
from torch.utils.data import DataLoader
from torch.utils.data import DataLoader, TensorDataset

from qadence.ml_tools.callbacks import CallbacksManager
from qadence.ml_tools.config import TrainConfig
from qadence.ml_tools.data import InfiniteTensorDataset
from qadence.ml_tools.data import DictDataLoader
from qadence.ml_tools.loss import get_loss_fn
from qadence.ml_tools.optimize_step import optimize_step
from qadence.ml_tools.parameters import get_parameters
Expand Down Expand Up @@ -42,9 +42,9 @@ class BaseTrainer:
model (nn.Module): The neural network model.
optimizer (optim.Optimizer | NGOptimizer | None): The optimizer for training.
config (TrainConfig): The configuration settings for training.
train_dataloader (DataLoader | None): DataLoader for training data.
val_dataloader (DataLoader | None): DataLoader for validation data.
test_dataloader (DataLoader | None): DataLoader for testing data.
train_dataloader (DataLoader | DictDataLoader | None): DataLoader for training data.
val_dataloader (DataLoader | DictDataLoader | None): DataLoader for validation data.
test_dataloader (DataLoader | DictDataLoader | None): DataLoader for testing data.

optimize_step (Callable): Function for performing an optimization step.
loss_fn (Callable | str): loss function to use. Default loss function
Expand All @@ -69,9 +69,9 @@ def __init__(
config: TrainConfig,
loss_fn: str | Callable = "mse",
optimize_step: Callable = optimize_step,
train_dataloader: DataLoader | None = None,
val_dataloader: DataLoader | None = None,
test_dataloader: DataLoader | None = None,
train_dataloader: DataLoader | DictDataLoader | None = None,
val_dataloader: DataLoader | DictDataLoader | None = None,
test_dataloader: DataLoader | DictDataLoader | None = None,
max_batches: int | None = None,
):
"""
Expand All @@ -86,11 +86,11 @@ def __init__(
str input to be specified to use a default loss function.
currently supported loss functions: 'mse', 'cross_entropy'.
If not specified, default mse loss will be used.
train_dataloader (DataLoader | None): DataLoader for training data.
train_dataloader (DataLoader | DictDataLoader | None): DataLoader for training data.
If the model does not need data to evaluate loss, no dataset
should be provided.
val_dataloader (DataLoader | None): DataLoader for validation data.
test_dataloader (DataLoader | None): DataLoader for testing data.
val_dataloader (DataLoader | DictDataLoader | None): DataLoader for validation data.
test_dataloader (DataLoader | DictDataLoader | None): DataLoader for testing data.
max_batches (int | None): Maximum number of batches to process per epoch.
This is only valid in case of finite TensorDataset dataloaders.
if max_batches is not None, the maximum number of batches used will
Expand All @@ -100,9 +100,9 @@ def __init__(
self._model: nn.Module
self._optimizer: optim.Optimizer | NGOptimizer | None
self._config: TrainConfig
self._train_dataloader: DataLoader | None = None
self._val_dataloader: DataLoader | None = None
self._test_dataloader: DataLoader | None = None
self._train_dataloader: DataLoader | DictDataLoader | None = None
self._val_dataloader: DataLoader | DictDataLoader | None = None
self._test_dataloader: DataLoader | DictDataLoader | None = None

self.config = config
self.model = model
Expand Down Expand Up @@ -311,7 +311,7 @@ def config(self, value: TrainConfig) -> None:
self.callback_manager = CallbacksManager(value)
self.config_manager = ConfigManager(value)

def _compute_num_batches(self, dataloader: DataLoader) -> int:
def _compute_num_batches(self, dataloader: DataLoader | DictDataLoader) -> int:
"""
Computes the number of batches for the given DataLoader.

Expand All @@ -321,34 +321,41 @@ def _compute_num_batches(self, dataloader: DataLoader) -> int:
"""
if dataloader is None:
return 1
dataset = dataloader.dataset
if isinstance(dataset, InfiniteTensorDataset):
return 1
if isinstance(dataloader, DictDataLoader):
dataloader_name, dataloader_value = list(dataloader.dataloaders.items())[0]
dataset = dataloader_value.dataset
batch_size = dataloader_value.batch_size
else:
n_batches = int(
(dataset.tensors[0].size(0) + dataloader.batch_size - 1) // dataloader.batch_size
)
dataset = dataloader.dataset
batch_size = dataloader.batch_size

if isinstance(dataset, TensorDataset):
n_batches = int((dataset.tensors[0].size(0) + batch_size - 1) // batch_size)
return min(self.max_batches, n_batches) if self.max_batches is not None else n_batches
else:
return 1

def _validate_dataloader(self, dataloader: DataLoader, dataloader_type: str) -> None:
def _validate_dataloader(
self, dataloader: DataLoader | DictDataLoader, dataloader_type: str
) -> None:
"""
Validates the type of the DataLoader and raises errors for unsupported types.

Args:
dataloader (DataLoader): The DataLoader to validate.
dataloader (DataLoader | DictDataLoader): The DataLoader to validate.
dataloader_type (str): The type of DataLoader ("train", "val", or "test").
"""
if dataloader is not None:
if not isinstance(dataloader, DataLoader):
if not isinstance(dataloader, (DataLoader, DictDataLoader)):
raise NotImplementedError(
f"Unsupported dataloader type: {type(dataloader)}."
"The dataloader must be an instance of DataLoader."
)
if dataloader_type == "val" and self.config.val_every > 0:
if not isinstance(dataloader, DataLoader):
if not isinstance(dataloader, (DataLoader, DictDataLoader)):
raise ValueError(
"If `config.val_every` is provided as an integer > 0, validation_dataloader"
"must be an instance of `DataLoader`."
"must be an instance of `DataLoader` or `DictDataLoader`."
)

@staticmethod
Expand Down
Loading
Loading