From 26db9f6af701fd492a565854082d2fec6e23c71a Mon Sep 17 00:00:00 2001 From: aaron-sandoval <32021231+aaron-sandoval@users.noreply.github.com> Date: Fri, 31 May 2024 23:32:44 -0600 Subject: [PATCH] `test_edge_subsets` passing --- maze_dataset/tokenization/maze_tokenizer.py | 7 +-- maze_dataset/utils.py | 52 ++++++++++++++++++- .../tokenization/test_token_utils.py | 3 +- .../tokenization/test_tokenizer.py | 19 +++---- 4 files changed, 67 insertions(+), 14 deletions(-) diff --git a/maze_dataset/tokenization/maze_tokenizer.py b/maze_dataset/tokenization/maze_tokenizer.py index cf857883..b5fd6934 100644 --- a/maze_dataset/tokenization/maze_tokenizer.py +++ b/maze_dataset/tokenization/maze_tokenizer.py @@ -17,7 +17,7 @@ Literal, TypedDict, ) -from jaxtyping import Int64, Int8 +from jaxtyping import Int64, Int8, Int import numpy as np from muutils.json_serialize import ( @@ -709,9 +709,10 @@ def _token_params(self) -> "EdgeGroupings._GroupingTokenParams": intra=self.intra ) - def _group_edges(self, edges: ConnectionList) -> Sequence[ConnectionList]: + def _group_edges(self, edges: ConnectionArray) -> Sequence[ConnectionArray]: # Adapted from: https://stackoverflow.com/questions/38013778/is-there-any-numpy-group-by-function - sorted_edges: ConnectionArray = np.lexsort((edges[:,1,1], edges[:,1,0], edges[:,0,1], edges[:,0,0])) + index_array: Int[np.ndarray, "sort_indices=edges"] = np.lexsort((edges[:,1,1], edges[:,1,0], edges[:,0,1], edges[:,0,0])) + sorted_edges: ConnectionArray = edges[index_array,...] 
groups: list[ConnectionArray] = np.split(sorted_edges, np.unique(sorted_edges[:,0,:], return_index=True, axis=0)[1][1:]) if self.shuffle_group: [numpy_rng.shuffle(g, axis=0) for g in groups] diff --git a/maze_dataset/utils.py b/maze_dataset/utils.py index 4f2e27fe..915deff1 100644 --- a/maze_dataset/utils.py +++ b/maze_dataset/utils.py @@ -39,7 +39,54 @@ class IsDataclass(Protocol): __dataclass_fields__: ClassVar[dict[str, Any]] -FiniteValued = TypeVar("FiniteValued", bool, IsDataclass, enum.Enum) +""" +# `FiniteValued` +The intended definition of this type is not possible to fully define via the Python 3.10 typing library. +This custom generic type is a generic domain of many types which have a finite, discrete range space. +It was created to define the domain of types for the `all_instances` function, since this function relies heavily on static typing. +It may later be used in related applications. +These types may be nested in an arbitrarily deep tree via Container Types and Superclass Types (see below). +The leaves of the tree must always be Primitive Types. + +# `FiniteValued` Subtypes +*: Indicates that this subtype is not yet supported by `all_instances` + +## Non-`FiniteValued` (Unbounded) Types +These are NOT valid subtypes, but are listed for illustrative purposes. +This list is not comprehensive. +While the finite nature of digital computers means that the cardinality of these types is technically finite, +they are considered unbounded types in this context. +- No container subtype may contain any of these unbounded subtypes. 
+- `int` +- `float` +- `str` +- `list` +- `set`: Set types without a fixed length are unbounded +- `tuple`: Tuple types without a fixed length are unbounded + +## Primitive Types +Primitive types are non-nested types which resolve directly to a concrete range of values +- `bool`: has 2 possible values +- `enum.Enum`: The range of a concrete `Enum` subclass is its set of enum members +- `typing.Literal`: Every type constructed using `Literal` has a finite set of possible literal values in its definition. +This is the preferred way to include limited ranges of non-`FiniteValued` types such as `int` or `str` in a `FiniteValued` hierarchy. + +## Container Types +Container types are types which contain zero or more fields of `FiniteValued` type. +The range of a container type is the cartesian product of its field types, except for `set[FiniteValued]`. +- `tuple[FiniteValued]`: Tuples of fixed length whose elements are each `FiniteValued`. +- `IsDataclass`: Concrete dataclasses whose fields are `FiniteValued`. +- *Standard concrete class: Regular classes could be supported just like dataclasses if all their data members are `FiniteValued`-typed. +- *`set[FiniteValued]`: Sets of fixed length of a `FiniteValued` type. + +## Superclass Types +Superclass types don't directly contain data members like container types. +Their range is the union of the ranges of their subtypes. 
+- Abstract dataclasses: Abstract dataclasses whose subclasses are all `FiniteValued` superclass or container types +- *Standard abstract classes: Abstract classes whose subclasses are all `FiniteValued` superclass or container types +- `UnionType`: Any union of `FiniteValued` types, e.g., bool | Literal[2, 3] +""" +FiniteValued = TypeVar("FiniteValued", bound=bool | IsDataclass | enum.Enum) def bool_array_from_string( @@ -444,6 +491,9 @@ def all_instances( elif get_origin(type_) == UnionType: # Union: call `all_instances` for each type in the Union return _apply_validation_func(type_, list(flatten([all_instances(sub, validation_funcs) for sub in get_args(type_)], levels_to_flatten=1)), validation_funcs) + elif get_origin(type_) is Literal: + # Literal: return all Literal arguments + return _apply_validation_func(type_, list(get_args(type_)), validation_funcs) elif type(type_) == enum.EnumMeta: # `issubclass(type_, enum.Enum)` doesn't work # Enum: return all Enum members raise NotImplementedError(f"Support for Enums not yet implemented.") diff --git a/tests/unit/maze_dataset/tokenization/test_token_utils.py b/tests/unit/maze_dataset/tokenization/test_token_utils.py index 03c2a605..eba9b446 100644 --- a/tests/unit/maze_dataset/tokenization/test_token_utils.py +++ b/tests/unit/maze_dataset/tokenization/test_token_utils.py @@ -1,4 +1,4 @@ -from typing import Iterable, TypeVar, Callable +from typing import Iterable, TypeVar, Callable, Literal from dataclasses import dataclass import pytest @@ -628,6 +628,7 @@ def foo(): pass (bool, [True, False]), (int, TypeError), (str, TypeError), + (Literal[0, 1, 2], [0, 1, 2]), (tuple[bool], [ (True,), diff --git a/tests/unit/maze_dataset/tokenization/test_tokenizer.py b/tests/unit/maze_dataset/tokenization/test_tokenizer.py index d281f6a5..af2932a7 100644 --- a/tests/unit/maze_dataset/tokenization/test_tokenizer.py +++ b/tests/unit/maze_dataset/tokenization/test_tokenizer.py @@ -781,14 +781,14 @@ def test_edge_subsets(es: 
EdgeSubsets.EdgeSubset, maze: LatticeMaze): @mark.parametrize( "tok_elem,es,maze", [ - param(tok_elem, maze, id=f"{tok_elem.name}-{es.name}-maze[{i}]") + param(tok_elem, es, maze, id=f"{tok_elem.name}-{es.name}-maze[{i}]") for (i, maze), tok_elem, es in itertools.product( enumerate(MIXED_MAZES[:6]), all_instances( EdgeGroupings.EdgeGrouping, frozendict.frozendict({ TokenizerElement: lambda x: x.is_valid(), - # Add a condition to trim out the param space that doesn't affect functionality being tested + # Add a condition to prune the range space that doesn't affect functionality being tested EdgeGroupings.ByLeadingCoord: lambda x: x.intra and x.connection_token_ordinal==1 }) ), @@ -810,15 +810,16 @@ def test_edge_subsets(tok_elem: EdgeGroupings.EdgeGrouping, es: EdgeSubsets.Edge case EdgeGroupings.ByLeadingCoord: assert len(groups) == np.unique(edges[:,0,:], axis=0).shape[0] assert sum(g.shape[0] for g in groups) == edges.shape[0] + trailing_coords: list[CoordArray] = [g[:,1,:] for g in groups] + # vector_diffs is the position vector difference between the trailing coords of each group + # These are stacked into a single array since we don't care about maintaining group separation + vector_diffs: CoordArray = np.stack(list(flatten([np.diff(g[:,1,:], axis=0) for g in groups], 1))) if tok_elem.shuffle_group: - ... + allowed_diffs = {(1,-1),(1,1),(0,2),(2,0)} + # The set of all 2D vectors between any 2 coords adjacent to a central coord + allowed_diffs = allowed_diffs.union({(-d[0], -d[1]) for d in allowed_diffs}) else: - trailing_coords: list[CoordArray] = [g[:,1,:] for g in groups] - # vector_diffs is the position vector difference between the trailing coords of each group - # These are stacked into a single array since we don't care about maintaining group separation - vector_diffs: CoordArray = np.stack(list(flatten([np.diff(g[:,1,:], axis=0) for g in groups], 1))) # If vector_diffs are lexicographically sorted, these are the only possible values. 
Any other value indicates an error in sorting allowed_diffs = {(1,-1),(1,1),(0,2),(2,0)} - # vector_diffs are - assert all(tuple(diff) in allowed_diffs for diff in np.unique(vector_diffs, axis=0)) + assert all(tuple(diff) in allowed_diffs for diff in np.unique(vector_diffs, axis=0)) \ No newline at end of file