Skip to content

Commit

Permalink
Extend IndexType to explicitly have list[int]:
Browse files Browse the repository at this point in the history
  * mypy can't infer that this type is also permitted and so it complains that Collector._reset_state's id is not compatible when used inside the function (e.g. state.empty_(id) when state is Batch)
  • Loading branch information
dantp-ai committed Mar 27, 2024
1 parent 75bb1f0 commit a0d1427
Showing 1 changed file with 7 additions and 1 deletion.
8 changes: 7 additions & 1 deletion tianshou/data/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,13 @@
import torch

_SingleIndexType = slice | int | EllipsisType
IndexType = np.ndarray | _SingleIndexType | list[_SingleIndexType] | tuple[_SingleIndexType, ...]
IndexType = (
np.ndarray
| _SingleIndexType
| list[_SingleIndexType]
| tuple[_SingleIndexType, ...]
| list[int]
)
TBatch = TypeVar("TBatch", bound="BatchProtocol")
arr_type = torch.Tensor | np.ndarray

Expand Down

0 comments on commit a0d1427

Please sign in to comment.