Skip to content

Commit

Permalink
Typing in tests: added asserts and cast to remove some mypy errors
Browse files Browse the repository at this point in the history
  • Loading branch information
Michael Panchenko committed Mar 21, 2024
1 parent c6aa77b commit 0de06a4
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 2 deletions.
5 changes: 5 additions & 0 deletions test/base/test_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import pickle
import sys
from itertools import starmap
from typing import cast

import networkx as nx
import numpy as np
Expand Down Expand Up @@ -160,7 +161,11 @@ def test_batch() -> None:
batch5 = Batch(a=np.array([{"index": 0}]))
assert isinstance(batch5.a, Batch)
assert np.allclose(batch5.a.index, [0])
# We use setattr b/c the setattr of Batch will actually change the type of the field that is being set!
# However, mypy would not understand this, and rightly expect that batch.b = some_array would lead to
# batch.b being an array (which it is not, it's turned into a Batch instead)
batch5.b = np.array([{"index": 1}])
batch5.b = cast(Batch, batch5.b)
assert isinstance(batch5.b, Batch)
assert np.allclose(batch5.b.index, [1])

Expand Down
6 changes: 4 additions & 2 deletions test/base/test_collector.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,9 +314,11 @@ def test_collector_with_dict_state() -> None:
batch, _ = c1.buffer.sample(10)
c0.buffer.update(c1.buffer)
assert len(c0.buffer) in [42, 43]
cur_obs = c0.buffer[:].obs
assert isinstance(cur_obs, Batch)
if len(c0.buffer) == 42:
assert np.all(
c0.buffer[:].obs.index[..., 0]
cur_obs.index[..., 0]
== [
0,
1,
Expand Down Expand Up @@ -364,7 +366,7 @@ def test_collector_with_dict_state() -> None:
), c0.buffer[:].obs.index[..., 0]
else:
assert np.all(
c0.buffer[:].obs.index[..., 0]
cur_obs.index[..., 0]
== [
0,
1,
Expand Down

0 comments on commit 0de06a4

Please sign in to comment.