From a0d1427b0170af8aef29d8c6a1ff93cbc3876fb5 Mon Sep 17 00:00:00 2001 From: daniel <1534513+dantp-ai@users.noreply.github.com> Date: Wed, 27 Mar 2024 15:52:47 +0100 Subject: [PATCH] Extend IndexType to explicitly have list[int]: * mypy can't infer that this type is also permitted and so it complains that Collector._reset_state's id is not compatible when used inside the function (e.g. state.empty_(id) when state is Batch) --- tianshou/data/batch.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/tianshou/data/batch.py b/tianshou/data/batch.py index 508e5c9a2..7a6bb0045 100644 --- a/tianshou/data/batch.py +++ b/tianshou/data/batch.py @@ -19,7 +19,13 @@ import torch _SingleIndexType = slice | int | EllipsisType -IndexType = np.ndarray | _SingleIndexType | list[_SingleIndexType] | tuple[_SingleIndexType, ...] +IndexType = ( + np.ndarray + | _SingleIndexType + | list[_SingleIndexType] + | tuple[_SingleIndexType, ...] + | list[int] +) TBatch = TypeVar("TBatch", bound="BatchProtocol") arr_type = torch.Tensor | np.ndarray