Expected behavior
The code from the N-BEATS tutorial:
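```python
import torch
from pytorch_forecasting import Baseline
from pytorch_forecasting.metrics import SMAPE
actuals = torch.cat([y[0] for x, y in iter(val_dataloader)])
baseline_predictions = Baseline().predict(val_dataloader)
SMAPE()(baseline_predictions, actuals)
```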
I expected this to return the SMAPE result without an error.
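For reference, the expected result is a single SMAPE score. The `SMAPE.loss` frame in the traceback under "Actual behavior" shows the elementwise formula pytorch-forecasting uses: `2 * |y_pred - target| / (|y_pred| + |target| + 1e-8)`. The following is a minimal pure-Python sketch of that formula. `smape` is a hypothetical helper, and it assumes the final score is a plain mean over all values, ignoring the weighting and sequence-length handling that `MultiHorizonMetric` performs.

```python
def smape(y_pred, target, eps=1e-8):
    """Hypothetical helper: mean SMAPE over paired values.

    Mirrors the elementwise expression in pytorch-forecasting's SMAPE.loss,
    2 * |y_pred - target| / (|y_pred| + |target| + eps), and then (as an
    assumption) reduces with a plain mean. Each term lies in [0, 2].
    """
    if len(y_pred) != len(target):
        raise ValueError("y_pred and target must have the same length")
    losses = [2 * abs(p - t) / (abs(p) + abs(t) + eps) for p, t in zip(y_pred, target)]
    return sum(losses) / len(losses)


print(smape([1.0, 2.0], [1.0, 4.0]))  # ~0.3333
```

A perfect prediction contributes 0, and a prediction paired with a zero target contributes close to the maximum of 2. The small `eps` only guards against division by zero when both values are 0.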
Actual behavior
However, the following errors were raised:
```
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
```
```
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
File ~/miniconda3/envs/pf/lib/python3.10/site-packages/torchmetrics/metric.py:446, in Metric._wrap_update.<locals>.wrapped_func(*args, **kwargs)
445 try:
--> 446 update(*args, **kwargs)
447 except RuntimeError as err:

File ~/miniconda3/envs/pf/lib/python3.10/site-packages/pytorch_forecasting/metrics/base_metrics.py:784, in MultiHorizonMetric.update(self, y_pred, target)
782 lengths = torch.full((target.size(0),), fill_value=target.size(1), dtype=torch.long, device=target.device)
--> 784 losses = self.loss(y_pred, target)
785 # weight samples

File ~/miniconda3/envs/pf/lib/python3.10/site-packages/pytorch_forecasting/metrics/point.py:69, in SMAPE.loss(self, y_pred, target)
68 y_pred = self.to_prediction(y_pred)
---> 69 loss = 2 * (y_pred - target).abs() / (y_pred.abs() + target.abs() + 1e-8)
70 return loss

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!

The above exception was the direct cause of the following exception:

RuntimeError Traceback (most recent call last)
Cell In[10], line 5
3 actuals = torch.cat([y[0] for x, y in iter(val_dataloader)])
4 baseline_predictions = Baseline().predict(val_dataloader)
----> 5 SMAPE()(baseline_predictions, actuals)

File ~/miniconda3/envs/pf/lib/python3.10/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
1496 # If we don't have any hooks, we want to skip the rest of the logic in
1497 # this function, and just call forward.
1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1499 or _global_backward_pre_hooks or _global_backward_hooks
1500 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501 return forward_call(*args, **kwargs)
1502 # Do not call functions when jit is used
1503 full_backward_hooks, non_full_backward_hooks = [], []

File ~/miniconda3/envs/pf/lib/python3.10/site-packages/torchmetrics/metric.py:290, in Metric.forward(self, *args, **kwargs)
288 self._forward_cache = self._forward_full_state_update(*args, **kwargs)
289 else:
--> 290 self._forward_cache = self._forward_reduce_state_update(*args, **kwargs)
292 return self._forward_cache

File ~/miniconda3/envs/pf/lib/python3.10/site-packages/torchmetrics/metric.py:357, in Metric._forward_reduce_state_update(self, *args, **kwargs)
354 self._enable_grad = True # allow grads for batch computation
356 # calculate batch state and compute batch value
--> 357 self.update(*args, **kwargs)
358 batch_val = self.compute()
360 # reduce batch and global state

File ~/miniconda3/envs/pf/lib/python3.10/site-packages/torchmetrics/metric.py:449, in Metric._wrap_update.<locals>.wrapped_func(*args, **kwargs)
447 except RuntimeError as err:
448 if "Expected all tensors to be on" in str(err):
--> 449 raise RuntimeError(
450 "Encountered different devices in metric calculation (see stacktrace for details)."
451 " This could be due to the metric class not being on the same device as input."
452 f" Instead of `metric={self.__class__.__name__}(...)` try to do"
453 f" `metric={self.__class__.__name__}(...).to(device)` where"
454 " device corresponds to the device of the input."
455 ) from err
456 raise err
458 if self.compute_on_cpu:

RuntimeError: Encountered different devices in metric calculation (see stacktrace for details). This could be due to the metric class not being on the same device as input. Instead of `metric=SMAPE(...)` try to do `metric=SMAPE(...).to(device)` where device corresponds to the device of the input.
```
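The first error names the two devices, `cuda:0` and `cpu`. Since `actuals` is concatenated from the dataloader on the CPU and the log shows the GPU was used for prediction, `baseline_predictions` presumably comes back on `cuda:0`. A minimal sketch of one workaround, assuming both values are plain tensors as the traceback suggests, is to move them onto one device before calling the metric. `to_common_device` is a hypothetical helper, not part of pytorch-forecasting, and the tensors below are stand-ins for the ones in the report.

```python
import torch


def to_common_device(y_pred: torch.Tensor, target: torch.Tensor, device=None):
    """Hypothetical helper: return y_pred and target on a single device.

    If device is None, target is moved to the device y_pred already lives on;
    otherwise both tensors are moved to the given device.
    """
    device = y_pred.device if device is None else torch.device(device)
    return y_pred.to(device), target.to(device)


# Stand-ins for the report's tensors: predictions on the GPU when one is
# available, targets collected from the dataloader on the CPU.
gpu_or_cpu = "cuda" if torch.cuda.is_available() else "cpu"
preds = torch.tensor([[1.0, 2.0]], device=gpu_or_cpu)
actuals = torch.tensor([[1.0, 4.0]])

preds, actuals = to_common_device(preds, actuals)
# Elementwise ops like the one in SMAPE.loss no longer mix devices.
abs_error = (preds - actuals).abs()
print(abs_error.device, abs_error)
```

Applied to the snippet in the report, this would be `baseline_predictions, actuals = to_common_device(baseline_predictions, actuals)` followed by `SMAPE().to(baseline_predictions.device)(baseline_predictions, actuals)`, which is the `.to(device)` pattern the error message recommends. Alternatively, `SMAPE()(baseline_predictions.cpu(), actuals)` keeps everything on the CPU, where a freshly constructed metric's state should live by default.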