Skip to content

Commit

Permalink
wip test_edge_subsets
Browse files Browse the repository at this point in the history
  • Loading branch information
aaron-sandoval committed Jun 1, 2024
1 parent 973df5a commit 64dd2ff
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 4 deletions.
16 changes: 14 additions & 2 deletions maze_dataset/tokenization/maze_tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -652,7 +652,7 @@ def _group_edges(self, edges: ConnectionArray) -> Sequence[ConnectionArray]:
"""Divides a ConnectionArray into groups of edges.
Shuffles/sequences within each group if applicable.
"""
...
pass

@abc.abstractmethod
def _token_params(self) -> "EdgeGroupings._GroupingTokenParams":
Expand Down Expand Up @@ -683,6 +683,8 @@ def _token_params(self) -> "EdgeGroupings._GroupingTokenParams":
intra=False
)

def _group_edges(self, edges: ConnectionList) -> Sequence[ConnectionList]:
    """Treat every edge as its own group.

    Inserts a singleton group axis at position 1, so an `(n, 2, 2)` edge
    array becomes `(n, 1, 2, 2)`: n groups of one edge each.
    """
    return edges[:, np.newaxis, ...]

@serializable_dataclass(frozen=True, kw_only=True)
class ByLeadingCoord(EdgeGrouping):
Expand All @@ -692,7 +694,8 @@ class ByLeadingCoord(EdgeGrouping):
- `intra`: Whether all edge groupings include a delimiter token between individual edge representations.
Note that each edge representation will already always include a connector token (`VOCAB.CONNECTOR`, or possibly `)
- `shuffle_group`: Whether the sequence of edges within the group should be shuffled or appear in a fixed order.
If false, the fixed order is NORTH, WEST, SOUTH, EAST, where the directions indicate the position of the connecting coord relative to the leading coord.
If false, the fixed order is lexicographical by (row, col).
In effect, lexicographical sorting sorts edges by their cardinal direction in the sequence NORTH, WEST, EAST, SOUTH, where the directions indicate the position of the trailing coord relative to the leading coord.
- `connection_token_ordinal`: At which index in token sequence representing a single edge the connector (or wall) token appears.
Edge tokenizations contain 2 parts: a connector (or wall) token and a coord or cardinal tokenization.
"""
Expand All @@ -706,6 +709,15 @@ def _token_params(self) -> "EdgeGroupings._GroupingTokenParams":
intra=self.intra
)

def _group_edges(self, edges: ConnectionList) -> Sequence[ConnectionList]:
# Adapted from: https://stackoverflow.com/questions/38013778/is-there-any-numpy-group-by-function
sorted_edges: ConnectionArray = np.lexsort((edges[:,1,1], edges[:,1,0], edges[:,0,1], edges[:,0,0]))
groups: list[ConnectionArray] = np.split(sorted_edges, np.unique(sorted_edges[:,0,:], return_index=True, axis=0)[1][1:])
if self.shuffle_group:
[numpy_rng.shuffle(g, axis=0) for g in groups]
return groups


class EdgePermuters(_TokenizerElementNamespace):
"""Namespace for `EdgePermuter` subclass hierarchy used by `AdjListTokenizer`.
"""
Expand Down
52 changes: 50 additions & 2 deletions tests/unit/maze_dataset/tokenization/test_tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
VOCAB_LIST,
Coord,
CoordTup,
CoordArray,
ConnectionArray,
LatticeMaze,
LatticeMazeGenerators,
Expand Down Expand Up @@ -45,7 +46,7 @@
TargetTokenizers,
TokenizationMode,
)
from maze_dataset.utils import all_instances, manhattan_distance
from maze_dataset.utils import all_instances, manhattan_distance, flatten
from maze_dataset.tokenization.maze_tokenizer import _load_tokenizer_hashes
from maze_dataset.util import equal_except_adj_list_sequence, connection_list_to_adj_list
from maze_dataset.token_utils import get_path_tokens
Expand Down Expand Up @@ -773,4 +774,51 @@ def test_edge_subsets(es: EdgeSubsets.EdgeSubset, maze: LatticeMaze):
assert edges.dtype == np.int8
assert assert_shape == tuple(edges.shape)
assert assert_shape == tuple(np.unique(edges, axis=0).shape) # All edges are unique (swapping leading/trailing coords is considered different)
assert np.array_equal(manhattan_distance(edges), np.array([1]*assert_shape[0], dtype=np.int8))
assert np.array_equal(manhattan_distance(edges), np.array([1]*assert_shape[0], dtype=np.int8))



@mark.parametrize(
"tok_elem,es,maze",
[
param(tok_elem, maze, id=f"{tok_elem.name}-{es.name}-maze[{i}]")
for (i, maze), tok_elem, es in itertools.product(
enumerate(MIXED_MAZES[:6]),
all_instances(
EdgeGroupings.EdgeGrouping,
frozendict.frozendict({
TokenizerElement: lambda x: x.is_valid(),
# Add a condition to trim out the param space that doesn't affect functionality being tested
EdgeGroupings.ByLeadingCoord: lambda x: x.intra and x.connection_token_ordinal==1
})
),
all_instances(
EdgeSubsets.EdgeSubset,
frozendict.frozendict({TokenizerElement: lambda x: x.is_valid()})
),
)
],
)
def test_edge_subsets(tok_elem: EdgeGroupings.EdgeGrouping, es: EdgeSubsets.EdgeSubset, maze: LatticeMaze):
edges: ConnectionArray = es._get_edges(maze)
n: int = maze.grid_n
groups: Sequence[ConnectionArray] = tok_elem._group_edges(edges)
match type(tok_elem):
case EdgeGroupings.Ungrouped:
assert_shape = edges.shape[0], 1, 2, 2
assert tuple(groups.shape) == assert_shape
case EdgeGroupings.ByLeadingCoord:
assert len(groups) == np.unique(edges[:,0,:], axis=0).shape[0]
assert sum(g.shape[0] for g in groups) == edges.shape[0]
if tok_elem.shuffle_group:
...
else:
trailing_coords: list[CoordArray] = [g[:,1,:] for g in groups]
# vector_diffs is the position vector difference between the trailing coords of each group
# These are stacked into a single array since we don't care about maintaining group separation
vector_diffs: CoordArray = np.stack(list(flatten([np.diff(g[:,1,:], axis=0) for g in groups], 1)))
# If vector_diffs are lexicographically sorted, these are the only possible values. Any other value indicates an error in sorting
allowed_diffs = {(1,-1),(1,1),(0,2),(2,0)}
# vector_diffs are
assert all(tuple(diff) in allowed_diffs for diff in np.unique(vector_diffs, axis=0))

0 comments on commit 64dd2ff

Please sign in to comment.