Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[bug] Resume dictdataloader support for Trainer #627

Merged
merged 7 commits into from
Dec 5, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 23 additions & 21 deletions qadence/ml_tools/callbacks/writer_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from torch.utils.tensorboard import SummaryWriter

from qadence.ml_tools.config import TrainConfig
from qadence.ml_tools.data import OptimizeResult
from qadence.ml_tools.data import DictDataLoader, OptimizeResult
from qadence.types import ExperimentTrackingTool

logger = getLogger("ml_tools")
Expand Down Expand Up @@ -104,18 +104,18 @@ def plot(
def log_model(
self,
model: Module,
train_dataloader: DataLoader | None = None,
val_dataloader: DataLoader | None = None,
test_dataloader: DataLoader | None = None,
train_dataloader: DataLoader | DictDataLoader | None = None,
val_dataloader: DataLoader | DictDataLoader | None = None,
test_dataloader: DataLoader | DictDataLoader | None = None,
) -> None:
"""
Logs the model and associated data.

Args:
model (Module): The model to log.
train_dataloader (DataLoader | None): DataLoader for training data.
val_dataloader (DataLoader | None): DataLoader for validation data.
test_dataloader (DataLoader | None): DataLoader for testing data.
train_dataloader (DataLoader | DictDataLoader | None): DataLoader for training data.
val_dataloader (DataLoader | DictDataLoader | None): DataLoader for validation data.
test_dataloader (DataLoader | DictDataLoader | None): DataLoader for testing data.
"""
raise NotImplementedError("Writers must implement a log_model method.")

Expand Down Expand Up @@ -231,9 +231,9 @@ def plot(
def log_model(
self,
model: Module,
train_dataloader: DataLoader | None = None,
val_dataloader: DataLoader | None = None,
test_dataloader: DataLoader | None = None,
train_dataloader: DataLoader | DictDataLoader | None = None,
val_dataloader: DataLoader | DictDataLoader | None = None,
test_dataloader: DataLoader | DictDataLoader | None = None,
) -> None:
"""
Logs the model.
Expand All @@ -242,9 +242,9 @@ def log_model(

Args:
model (Module): The model to log.
train_dataloader (DataLoader | None): DataLoader for training data.
val_dataloader (DataLoader | None): DataLoader for validation data.
test_dataloader (DataLoader | None): DataLoader for testing data.
train_dataloader (DataLoader | DictDataLoader | None): DataLoader for training data.
val_dataloader (DataLoader | DictDataLoader | None): DataLoader for validation data.
test_dataloader (DataLoader | DictDataLoader | None): DataLoader for testing data.
"""
logger.warning("Model logging is not supported by tensorboard. No model will be logged.")

Expand Down Expand Up @@ -356,13 +356,15 @@ def plot(
"Please call the 'writer.open()' method before writing"
)

def get_signature_from_dataloader(self, model: Module, dataloader: DataLoader | None) -> Any:
def get_signature_from_dataloader(
self, model: Module, dataloader: DataLoader | DictDataLoader | None
) -> Any:
"""
Infers the signature of the model based on the input data from the dataloader.

Args:
model (Module): The model to use for inference.
dataloader (DataLoader | None): DataLoader for model inputs.
dataloader (DataLoader | DictDataLoader | None): DataLoader for model inputs.

Returns:
Optional[Any]: The inferred signature, if available.
Expand All @@ -384,18 +386,18 @@ def get_signature_from_dataloader(self, model: Module, dataloader: DataLoader |
def log_model(
self,
model: Module,
train_dataloader: DataLoader | None = None,
val_dataloader: DataLoader | None = None,
test_dataloader: DataLoader | None = None,
train_dataloader: DataLoader | DictDataLoader | None = None,
val_dataloader: DataLoader | DictDataLoader | None = None,
test_dataloader: DataLoader | DictDataLoader | None = None,
) -> None:
"""
Logs the model and its signature to MLflow using the provided data loaders.

Args:
model (Module): The model to log.
train_dataloader (DataLoader | None): DataLoader for training data.
val_dataloader (DataLoader | None): DataLoader for validation data.
test_dataloader (DataLoader | None): DataLoader for testing data.
train_dataloader (DataLoader | DictDataLoader | None): DataLoader for training data.
val_dataloader (DataLoader | DictDataLoader | None): DataLoader for validation data.
test_dataloader (DataLoader | DictDataLoader | None): DataLoader for testing data.
"""
if not self.mlflow:
raise RuntimeError(
Expand Down
38 changes: 20 additions & 18 deletions qadence/ml_tools/train_utils/base_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

from qadence.ml_tools.callbacks import CallbacksManager
from qadence.ml_tools.config import TrainConfig
from qadence.ml_tools.data import InfiniteTensorDataset
from qadence.ml_tools.data import DictDataLoader, InfiniteTensorDataset
from qadence.ml_tools.loss import get_loss_fn
from qadence.ml_tools.optimize_step import optimize_step
from qadence.ml_tools.parameters import get_parameters
Expand Down Expand Up @@ -42,9 +42,9 @@ class BaseTrainer:
model (nn.Module): The neural network model.
optimizer (optim.Optimizer | NGOptimizer | None): The optimizer for training.
config (TrainConfig): The configuration settings for training.
train_dataloader (DataLoader | None): DataLoader for training data.
val_dataloader (DataLoader | None): DataLoader for validation data.
test_dataloader (DataLoader | None): DataLoader for testing data.
    train_dataloader (DataLoader | DictDataLoader | None): DataLoader for training data.
    val_dataloader (DataLoader | DictDataLoader | None): DataLoader for validation data.
    test_dataloader (DataLoader | DictDataLoader | None): DataLoader for testing data.

optimize_step (Callable): Function for performing an optimization step.
    loss_fn (Callable | str): loss function to use. Default loss function
Expand All @@ -69,9 +69,9 @@ def __init__(
config: TrainConfig,
loss_fn: str | Callable = "mse",
optimize_step: Callable = optimize_step,
train_dataloader: DataLoader | None = None,
val_dataloader: DataLoader | None = None,
test_dataloader: DataLoader | None = None,
train_dataloader: DataLoader | DictDataLoader | None = None,
val_dataloader: DataLoader | DictDataLoader | None = None,
test_dataloader: DataLoader | DictDataLoader | None = None,
max_batches: int | None = None,
):
"""
Expand All @@ -86,11 +86,11 @@ def __init__(
str input to be specified to use a default loss function.
currently supported loss functions: 'mse', 'cross_entropy'.
If not specified, default mse loss will be used.
train_dataloader (DataLoader | None): DataLoader for training data.
    train_dataloader (DataLoader | DictDataLoader | None): DataLoader for training data.
If the model does not need data to evaluate loss, no dataset
should be provided.
val_dataloader (DataLoader | None): DataLoader for validation data.
test_dataloader (DataLoader | None): DataLoader for testing data.
    val_dataloader (DataLoader | DictDataLoader | None): DataLoader for validation data.
    test_dataloader (DataLoader | DictDataLoader | None): DataLoader for testing data.
max_batches (int | None): Maximum number of batches to process per epoch.
This is only valid in case of finite TensorDataset dataloaders.
if max_batches is not None, the maximum number of batches used will
Expand All @@ -100,9 +100,9 @@ def __init__(
self._model: nn.Module
self._optimizer: optim.Optimizer | NGOptimizer | None
self._config: TrainConfig
self._train_dataloader: DataLoader | None = None
self._val_dataloader: DataLoader | None = None
self._test_dataloader: DataLoader | None = None
self._train_dataloader: DataLoader | DictDataLoader | None = None
self._val_dataloader: DataLoader | DictDataLoader | None = None
self._test_dataloader: DataLoader | DictDataLoader | None = None

self.config = config
self.model = model
Expand Down Expand Up @@ -330,25 +330,27 @@ def _compute_num_batches(self, dataloader: DataLoader) -> int:
)
return min(self.max_batches, n_batches) if self.max_batches is not None else n_batches

def _validate_dataloader(self, dataloader: DataLoader, dataloader_type: str) -> None:
def _validate_dataloader(
self, dataloader: DataLoader | DictDataLoader, dataloader_type: str
) -> None:
"""
Validates the type of the DataLoader and raises errors for unsupported types.

Args:
dataloader (DataLoader): The DataLoader to validate.
dataloader (DataLoader | DictDataLoader): The DataLoader to validate.
dataloader_type (str): The type of DataLoader ("train", "val", or "test").
"""
if dataloader is not None:
if not isinstance(dataloader, DataLoader):
if not (isinstance(dataloader, DataLoader) or isinstance(dataloader, DictDataLoader)):
mlahariya marked this conversation as resolved.
Show resolved Hide resolved
raise NotImplementedError(
    f"Unsupported dataloader type: {type(dataloader)}. "
    "The dataloader must be an instance of DataLoader or DictDataLoader."
)
if dataloader_type == "val" and self.config.val_every > 0:
if not isinstance(dataloader, DataLoader):
if not (isinstance(dataloader, DataLoader) or isinstance(dataloader, DictDataLoader)):
mlahariya marked this conversation as resolved.
Show resolved Hide resolved
raise ValueError(
"If `config.val_every` is provided as an integer > 0, validation_dataloader"
"must be an instance of `DataLoader`."
"must be an instance of `DataLoader` or `DictDataLoader`."
)

@staticmethod
Expand Down
28 changes: 15 additions & 13 deletions qadence/ml_tools/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from torch.utils.data import DataLoader

from qadence.ml_tools.config import TrainConfig
from qadence.ml_tools.data import OptimizeResult
from qadence.ml_tools.data import DictDataLoader, OptimizeResult
from qadence.ml_tools.optimize_step import optimize_step, update_ng_parameters
from qadence.ml_tools.stages import TrainingStage

Expand Down Expand Up @@ -49,9 +49,9 @@ class Trainer(BaseTrainer):
model (nn.Module): The neural network model.
optimizer (optim.Optimizer | NGOptimizer | None): The optimizer for training.
config (TrainConfig): The configuration settings for training.
train_dataloader (DataLoader | None): DataLoader for training data.
val_dataloader (DataLoader | None): DataLoader for validation data.
test_dataloader (DataLoader | None): DataLoader for testing data.
train_dataloader (DataLoader | DictDataLoader | None): DataLoader for training data.
val_dataloader (DataLoader | DictDataLoader | None): DataLoader for validation data.
test_dataloader (DataLoader | DictDataLoader | None): DataLoader for testing data.

optimize_step (Callable): Function for performing an optimization step.
loss_fn (Callable): loss function to use.
Expand Down Expand Up @@ -235,9 +235,9 @@ def __init__(
optimizer: optim.Optimizer | NGOptimizer | None,
config: TrainConfig,
loss_fn: str | Callable = "mse",
train_dataloader: DataLoader | None = None,
val_dataloader: DataLoader | None = None,
test_dataloader: DataLoader | None = None,
train_dataloader: DataLoader | DictDataLoader | None = None,
val_dataloader: DataLoader | DictDataLoader | None = None,
test_dataloader: DataLoader | DictDataLoader | None = None,
optimize_step: Callable = optimize_step,
device: torch_device | None = None,
dtype: torch_dtype | None = None,
Expand All @@ -252,9 +252,9 @@ def __init__(
config (TrainConfig): Training configuration object.
loss_fn (str | Callable ): Loss function used for training.
If not specified, default mse loss will be used.
train_dataloader (DataLoader | None): DataLoader for training data.
val_dataloader (DataLoader | None): DataLoader for validation data.
test_dataloader (DataLoader | None): DataLoader for test data.
train_dataloader (DataLoader | DictDataLoader | None): DataLoader for training data.
val_dataloader (DataLoader | DictDataLoader | None): DataLoader for validation data.
test_dataloader (DataLoader | DictDataLoader | None): DataLoader for test data.
optimize_step (Callable): Function to execute an optimization step.
device (torch_device): Device to use for computation.
dtype (torch_dtype): Data type for computation.
Expand Down Expand Up @@ -285,7 +285,9 @@ def __init__(
self.data_dtype = float64 if (self.dtype == complex128) else float32

def fit(
self, train_dataloader: DataLoader | None = None, val_dataloader: DataLoader | None = None
self,
train_dataloader: DataLoader | DictDataLoader | None = None,
val_dataloader: DataLoader | DictDataLoader | None = None,
) -> tuple[nn.Module, optim.Optimizer]:
"""
Fits the model using the specified training configuration.
Expand All @@ -294,8 +296,8 @@ def fit(
provided in the trainer will be used.

Args:
train_dataloader (DataLoader | None): DataLoader for training data.
val_dataloader (DataLoader | None): DataLoader for validation data.
train_dataloader (DataLoader | DictDataLoader | None): DataLoader for training data.
val_dataloader (DataLoader | DictDataLoader | None): DataLoader for validation data.

Returns:
tuple[nn.Module, optim.Optimizer]: The trained model and optimizer.
Expand Down
4 changes: 2 additions & 2 deletions tests/ml_tools/test_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,7 @@ def loss_fn(model: torch.nn.Module, data: torch.Tensor) -> tuple[torch.Tensor, d
trainer.fit()
assert (
"If `config.val_every` is provided as an integer > 0, validation_dataloader"
"must be an instance of `DataLoader`." in exc_info.exconly()
"must be an instance of `DataLoader` or `DictDataLoader`." in exc_info.exconly()
)


Expand Down Expand Up @@ -274,7 +274,7 @@ def loss_fn(model: torch.nn.Module, data: torch.Tensor) -> tuple[torch.Tensor, d
trainer.fit()
assert (
"If `config.val_every` is provided as an integer > 0, validation_dataloader"
"must be an instance of `DataLoader`." in exc_info.exconly()
"must be an instance of `DataLoader` or `DictDataLoader`." in exc_info.exconly()
)


Expand Down
Loading