Skip to content

Commit

Permalink
test_edge_subsets passing
Browse files Browse the repository at this point in the history
  • Loading branch information
aaron-sandoval committed Jun 1, 2024
1 parent 64dd2ff commit 26db9f6
Show file tree
Hide file tree
Showing 4 changed files with 67 additions and 14 deletions.
7 changes: 4 additions & 3 deletions maze_dataset/tokenization/maze_tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
Literal,
TypedDict,
)
from jaxtyping import Int64, Int8
from jaxtyping import Int64, Int8, Int

import numpy as np
from muutils.json_serialize import (
Expand Down Expand Up @@ -709,9 +709,10 @@ def _token_params(self) -> "EdgeGroupings._GroupingTokenParams":
intra=self.intra
)

def _group_edges(self, edges: ConnectionList) -> Sequence[ConnectionList]:
def _group_edges(self, edges: ConnectionArray) -> Sequence[ConnectionArray]:
# Adapted from: https://stackoverflow.com/questions/38013778/is-there-any-numpy-group-by-function
sorted_edges: ConnectionArray = np.lexsort((edges[:,1,1], edges[:,1,0], edges[:,0,1], edges[:,0,0]))
index_array: Int[np.ndarray, "sort_indices=edges"] = np.lexsort((edges[:,1,1], edges[:,1,0], edges[:,0,1], edges[:,0,0]))
sorted_edges: ConnectionArray = edges[index_array,...]
groups: list[ConnectionArray] = np.split(sorted_edges, np.unique(sorted_edges[:,0,:], return_index=True, axis=0)[1][1:])
if self.shuffle_group:
[numpy_rng.shuffle(g, axis=0) for g in groups]
Expand Down
52 changes: 51 additions & 1 deletion maze_dataset/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,54 @@ class IsDataclass(Protocol):
__dataclass_fields__: ClassVar[dict[str, Any]]


FiniteValued = TypeVar("FiniteValued", bool, IsDataclass, enum.Enum)
"""
# `FiniteValued`
The intended definition of this type cannot be fully expressed via the Python 3.10 typing library.
This custom generic type is a generic domain of many types which have a finite, discrete range space.
It was created to define the domain of types for the `all_instances` function, since this function relies heavily on static typing.
It may later be used in related applications.
These types may be nested in an arbitrarily deep tree via Container Types and Superclass Types (see below).
The leaves of the tree must always be Primitive Types.
# `FiniteValued` Subtypes
*: Indicates that this subtype is not yet supported by `all_instances`
## Non-`FiniteValued` (Unbounded) Types
These are NOT valid subtypes, but are listed for illustrative purposes.
This list is not comprehensive.
While the finite nature of digital computers means that the cardinality of these types is technically finite,
they are considered unbounded types in this context.
- No container subtype may contain any of these unbounded subtypes.
- `int`
- `float`
- `str`
- `list`
- `set`: Set types without a fixed length are unbounded
- `tuple`: Tuple types without a fixed length are unbounded
## Primitive Types
Primitive types are non-nested types which resolve directly to a concrete range of values
- `bool`: has 2 possible values
- `enum.Enum`: The range of a concrete `Enum` subclass is its set of enum members
- `typing.Literal`: Every type constructed using `Literal` has a finite set of possible literal values in its definition.
This is the preferred way to include limited ranges of non-`FiniteValued` types such as `int` or `str` in a `FiniteValued` hierarchy.
## Container Types
Container types are types which contain zero or more fields of `FiniteValued` type.
The range of a container type is the Cartesian product of the ranges of its field types, except for `set[FiniteValued]`.
- `tuple[FiniteValued]`: Tuples of fixed length whose elements are each `FiniteValued`.
- `IsDataclass`: Concrete dataclasses whose fields are `FiniteValued`.
- *Standard concrete class: Regular classes could be supported just like dataclasses if all their data members are `FiniteValued`-typed.
- *`set[FiniteValued]`: Sets of fixed length of a `FiniteValued` type.
## Superclass Types
Superclass types don't directly contain data members like container types.
Their range is the union of the ranges of their subtypes.
- Abstract dataclasses: Abstract dataclasses whose subclasses are all `FiniteValued` superclass or container types
- *Standard abstract classes: Abstract classes whose subclasses are all `FiniteValued` superclass or container types
- `UnionType`: Any union of `FiniteValued` types, e.g., bool | Literal[2, 3]
"""
FiniteValued = TypeVar("FiniteValued", bound=bool | IsDataclass | enum.Enum)


def bool_array_from_string(
Expand Down Expand Up @@ -444,6 +491,9 @@ def all_instances(
elif get_origin(type_) == UnionType:
# Union: call `all_instances` for each type in the Union
return _apply_validation_func(type_, list(flatten([all_instances(sub, validation_funcs) for sub in get_args(type_)], levels_to_flatten=1)), validation_funcs)
elif get_origin(type_) is Literal:
# Literal: return all Literal arguments
return _apply_validation_func(type_, list(get_args(type_)), validation_funcs)
elif type(type_) == enum.EnumMeta: # `issubclass(type_, enum.Enum)` doesn't work
# Enum: return all Enum members
raise NotImplementedError(f"Support for Enums not yet implemented.")
Expand Down
3 changes: 2 additions & 1 deletion tests/unit/maze_dataset/tokenization/test_token_utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Iterable, TypeVar, Callable
from typing import Iterable, TypeVar, Callable, Literal
from dataclasses import dataclass

import pytest
Expand Down Expand Up @@ -628,6 +628,7 @@ def foo(): pass
(bool, [True, False]),
(int, TypeError),
(str, TypeError),
(Literal[0, 1, 2], [0, 1, 2]),
(tuple[bool],
[
(True,),
Expand Down
19 changes: 10 additions & 9 deletions tests/unit/maze_dataset/tokenization/test_tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -781,14 +781,14 @@ def test_edge_subsets(es: EdgeSubsets.EdgeSubset, maze: LatticeMaze):
@mark.parametrize(
"tok_elem,es,maze",
[
param(tok_elem, maze, id=f"{tok_elem.name}-{es.name}-maze[{i}]")
param(tok_elem, es, maze, id=f"{tok_elem.name}-{es.name}-maze[{i}]")
for (i, maze), tok_elem, es in itertools.product(
enumerate(MIXED_MAZES[:6]),
all_instances(
EdgeGroupings.EdgeGrouping,
frozendict.frozendict({
TokenizerElement: lambda x: x.is_valid(),
# Add a condition to trim out the param space that doesn't affect functionality being tested
# Add a condition to prune the range space that doesn't affect functionality being tested
EdgeGroupings.ByLeadingCoord: lambda x: x.intra and x.connection_token_ordinal==1
})
),
Expand All @@ -810,15 +810,16 @@ def test_edge_subsets(tok_elem: EdgeGroupings.EdgeGrouping, es: EdgeSubsets.Edge
case EdgeGroupings.ByLeadingCoord:
assert len(groups) == np.unique(edges[:,0,:], axis=0).shape[0]
assert sum(g.shape[0] for g in groups) == edges.shape[0]
trailing_coords: list[CoordArray] = [g[:,1,:] for g in groups]
# vector_diffs is the position vector difference between the trailing coords of each group
# These are stacked into a single array since we don't care about maintaining group separation
vector_diffs: CoordArray = np.stack(list(flatten([np.diff(g[:,1,:], axis=0) for g in groups], 1)))
if tok_elem.shuffle_group:
...
allowed_diffs = {(1,-1),(1,1),(0,2),(2,0)}
# The set of all 2D vectors between any 2 coords adjacent to a central coord
allowed_diffs = allowed_diffs.union({(-d[0], -d[1]) for d in allowed_diffs})
else:
trailing_coords: list[CoordArray] = [g[:,1,:] for g in groups]
# vector_diffs is the position vector difference between the trailing coords of each group
# These are stacked into a single array since we don't care about maintaining group separation
vector_diffs: CoordArray = np.stack(list(flatten([np.diff(g[:,1,:], axis=0) for g in groups], 1)))
# If vector_diffs are lexicographically sorted, these are the only possible values. Any other value indicates an error in sorting
allowed_diffs = {(1,-1),(1,1),(0,2),(2,0)}
# vector_diffs are
assert all(tuple(diff) in allowed_diffs for diff in np.unique(vector_diffs, axis=0))
assert all(tuple(diff) in allowed_diffs for diff in np.unique(vector_diffs, axis=0))

0 comments on commit 26db9f6

Please sign in to comment.