diff --git a/docs/requirements.txt b/docs/requirements.txt index e212cd942f4..90efea35854 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -28,3 +28,7 @@ vmas onnxscript onnxruntime onnx +plotly +igraph +transformers +datasets diff --git a/docs/source/_static/img/rollout-llm.png b/docs/source/_static/img/rollout-llm.png new file mode 100644 index 00000000000..b2e63394de1 Binary files /dev/null and b/docs/source/_static/img/rollout-llm.png differ diff --git a/docs/source/index.rst b/docs/source/index.rst index 2eedc045416..6a448d61c41 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -105,6 +105,7 @@ Intermediate tutorials/dqn_with_rnn tutorials/rb_tutorial tutorials/export + tutorials/beam_search_with_gpt Advanced -------- diff --git a/test/mocking_classes.py b/test/mocking_classes.py index 4e943e03cfc..77fbff865be 100644 --- a/test/mocking_classes.py +++ b/test/mocking_classes.py @@ -1929,14 +1929,18 @@ def __init__(self): tensor=Unbounded(3), non_tensor=NonTensor(shape=()), ) + self._saved_obs_spec = self.observation_spec.clone() self.state_spec = Composite( non_tensor=NonTensor(shape=()), ) + self._saved_state_spec = self.state_spec.clone() self.reward_spec = Unbounded(1) + self._saved_full_reward_spec = self.full_reward_spec.clone() self.action_spec = Unbounded(1) + self._saved_full_action_spec = self.full_action_spec.clone() def _reset(self, tensordict): - data = self.observation_spec.zero() + data = self._saved_obs_spec.zero() data.set_non_tensor("non_tensor", 0) data.update(self.full_done_spec.zero()) return data @@ -1945,10 +1949,10 @@ def _step( self, tensordict: TensorDictBase, ) -> TensorDictBase: - data = self.observation_spec.zero() + data = self._saved_obs_spec.zero() data.set_non_tensor("non_tensor", tensordict["non_tensor"] + 1) data.update(self.full_done_spec.zero()) - data.update(self.full_reward_spec.zero()) + data.update(self._saved_full_reward_spec.zero()) return data def _set_seed(self, seed: Optional[int]): diff --git a/test/test_env.py b/test/test_env.py index b48b1a1cf8f..781f7fe71c0 100644 --- a/test/test_env.py +++ b/test/test_env.py @@ -3528,8 +3528,13 @@ def test_single_env_spec(): assert env.input_spec.is_in(env.input_spec_unbatched.zeros(env.shape)) -def test_auto_spec(): - env = CountingEnv() +@pytest.mark.parametrize("env_type", [CountingEnv, EnvWithMetadata]) +def test_auto_spec(env_type): + if env_type is EnvWithMetadata: + obs_vals = ["tensor", "non_tensor"] + else: + obs_vals = "observation" + env = env_type() td = env.reset() policy = lambda td, action_spec=env.full_action_spec.clone(): td.update( @@ -3552,7 +3557,7 @@ def test_auto_spec(): shape=env.full_state_spec.shape, device=env.full_state_spec.device ) env._action_keys = ["action"] - env.auto_specs_(policy, tensordict=td.copy()) + env.auto_specs_(policy, tensordict=td.copy(), observation_key=obs_vals) env.check_env_specs(tensordict=td.copy()) diff --git a/torchrl/_utils.py b/torchrl/_utils.py index d37aebb862f..c81ffcc962b 100644 --- a/torchrl/_utils.py +++ b/torchrl/_utils.py @@ -829,6 +829,7 @@ def _can_be_pickled(obj): def _make_ordinal_device(device: torch.device): if device is None: return device + device = torch.device(device) if device.type == "cuda" and device.index is None: return torch.device("cuda", index=torch.cuda.current_device()) if device.type == "mps" and device.index is None: diff --git a/torchrl/data/map/hash.py b/torchrl/data/map/hash.py index 01988dc43be..59526628dbe 100644 --- a/torchrl/data/map/hash.py +++ b/torchrl/data/map/hash.py @@ -75,7 
+75,8 @@ def forward(self, features: torch.Tensor) -> torch.Tensor:
 class SipHash(Module):
     """A Module to Compute SipHash values for given tensors.
 
-    A hash function module based on SipHash implementation in python.
+    A hash function module based on the SipHash implementation in Python. Input tensors should have shape ``[batch_size, num_features]``
+    and the output shape will be ``[batch_size]``.
 
     Args:
         as_tensor (bool, optional): if ``True``, the bytes will be turned into integers
diff --git a/torchrl/data/map/tdstorage.py b/torchrl/data/map/tdstorage.py
index a601f1e3261..6ff17daaed5 100644
--- a/torchrl/data/map/tdstorage.py
+++ b/torchrl/data/map/tdstorage.py
@@ -177,7 +177,7 @@ def from_tensordict_pair(
         collate_fn: Callable[[Any], Any] | None = None,
         write_fn: Callable[[Any, Any], Any] | None = None,
         consolidated: bool | None = None,
-    ):
+    ) -> TensorDictMap:
         """Creates a new TensorDictStorage from a pair of tensordicts (source and dest) using pre-defined rules of thumb.
 
         Args:
@@ -308,7 +308,23 @@ def __setitem__(self, item: TensorDictBase, value: TensorDictBase):
         if not self._has_lazy_out_keys():
             # TODO: make this work with pytrees and avoid calling select if keys match
             value = value.select(*self.out_keys, strict=False)
+        item, value = self._maybe_add_batch(item, value)
+        index = self._to_index(item, extend=True)
+        if index.unique().numel() < index.numel():
+            # If multiple values point to the same place in the storage, we cannot process them as a batch
+            # There could be a better way to deal with this, using unique ids.
+            vals = []
+            for it, val in zip(item.split(1), value.split(1)):
+                self[it] = val
+                vals.append(val)
+            # __setitem__ may affect the content of the input data
+            value.update(TensorDictBase.lazy_stack(vals))
+            return
         if self.write_fn is not None:
+            # We use this block in the following context: the value written in the storage is already present,
+            # but it needs to be updated.
+            # We first check if the value is already there using `contains`. If so, we pass the new value and the
+            # previous one to write_fn. The values that are not present are passed alone.
             if len(self):
                 modifiable = self.contains(item)
                 if modifiable.any():
@@ -322,8 +338,6 @@ def __setitem__(self, item: TensorDictBase, value: TensorDictBase):
                     value = self.write_fn(value)
             else:
                 value = self.write_fn(value)
-        item, value = self._maybe_add_batch(item, value)
-        index = self._to_index(item, extend=True)
         self.storage.set(index, value)
 
     def __len__(self):
diff --git a/torchrl/data/map/tree.py b/torchrl/data/map/tree.py
index 645f7704ddd..b88e6a4a2ec 100644
--- a/torchrl/data/map/tree.py
+++ b/torchrl/data/map/tree.py
@@ -15,6 +15,7 @@
     TensorClass,
     TensorDict,
     TensorDictBase,
+    unravel_key,
 )
 from torchrl.data.map.tdstorage import TensorDictMap
 from torchrl.data.map.utils import _plot_plotly_box, _plot_plotly_tree
@@ -94,7 +95,7 @@ def num_children(self) -> int:
 
     @property
     def is_terminal(self):
-        """Returns True if the the tree has no children nodes."""
+        """Returns True if the tree has no children nodes."""
         return self.subtree is None
 
     def get_vertex_by_id(self, id: int) -> Tree:
@@ -163,9 +164,6 @@ def vertices(
             if h in memo and not use_path:
                 continue
             memo.add(h)
-            r = tree.rollout
-            if r is not None:
-                r = r["next", "observation"]
             if use_path:
                 result[cur_path] = tree
             elif use_id:
@@ -206,6 +204,14 @@ def num_vertices(self, *, count_repeat: bool = False) -> int:
         )
 
     def edges(self) -> List[Tuple[int, int]]:
+        """Retrieves a list of edges in the tree.
+
+        Each edge is represented as a tuple of two node IDs: the parent node ID and the child node ID.
+        The tree is traversed using Breadth-First Search (BFS) to ensure all edges are visited.
+
+        Returns:
+            A list of tuples, where each tuple contains a parent node ID and a child node ID.
+        """
         result = []
         q = deque()
         parent = self.node_id
@@ -221,22 +227,62 @@ def edges(self) -> List[Tuple[int, int]]:
         return result
 
     def valid_paths(self):
+        """Generates all valid paths in the tree.
+
+        A valid path is a sequence of child indices that starts at the root node and ends at a leaf node.
+        Each path is represented as a tuple of integers, where each integer corresponds to the index of a child node.
+
+        Yields:
+            tuple: A valid path in the tree.
+        """
+        # Initialize a queue with the current tree node and an empty path
         q = deque()
         cur_path = ()
         q.append((self, cur_path))
+        # Perform BFS traversal of the tree
         while len(q):
+            # Dequeue the next tree node and its current path
            tree, cur_path = q.popleft()
+            # Get the number of child nodes
            n = int(tree.num_children)
+            # If this is a leaf node, yield the current path
            if not n:
                yield cur_path
+            # Iterate over the child nodes
            for i in range(n):
                cur_path_tree = cur_path + (i,)
                q.append((tree.subtree[i], cur_path_tree))
 
     def max_length(self):
-        return max(*(len(path) for path in self.valid_paths()))
+        """Returns the maximum length of all valid paths in the tree.
+
+        The length of a path is defined as the number of steps from the root to a leaf.
+        If the tree is empty, returns 0.
+
+        Returns:
+            int: The maximum length of all valid paths in the tree.
+
+        """
+        lengths = tuple(len(path) for path in self.valid_paths())
+        if len(lengths) == 0:
+            return 0
+        elif len(lengths) == 1:
+            return lengths[0]
+        return max(*lengths)
 
     def rollout_from_path(self, path: Tuple[int]) -> TensorDictBase | None:
+        """Retrieves the rollout data along a given path in the tree.
+
+        The rollout data is concatenated along the last dimension (dim=-1) for each node in the path.
+        If no rollout data is found along the path, returns ``None``.
+
+        Args:
+            path: A tuple of integers representing the path in the tree.
+
+        Returns:
+            The concatenated rollout data along the path, or ``None`` if no data is found.
+
+        """
         r = self.rollout
         tree = self
         rollouts = []
@@ -272,8 +318,19 @@ def plot(
         backend: str = "plotly",
         figure: str = "tree",
         info: List[str] = None,
-        make_labels: Callable[[Any], Any] | None = None,
+        make_labels: Callable[..., Any] | None = None,
     ):
+        """Plots a visualization of the tree using the specified backend and figure type.
+
+        Args:
+            backend: The plotting backend to use. Currently only supports 'plotly'.
+            figure: The type of figure to plot. Can be either 'tree' or 'box'.
+            info: A list of additional information to include in the plot (not currently used).
+            make_labels: An optional function to generate custom labels for the plot.
+
+        Raises:
+            NotImplementedError: If an unsupported backend or figure type is specified.
+        """
         if backend == "plotly":
             if figure == "box":
                 _plot_plotly_box(self)
@@ -284,7 +341,7 @@ def plot(
         else:
             pass
         raise NotImplementedError(
-            f"Unkown plotting backend {backend} with figure {figure}."
+            f"Unknown plotting backend {backend} with figure {figure}."
         )
 
 
@@ -423,47 +480,99 @@ def __init__(
         self.consolidated = consolidated
 
     @property
-    def done_keys(self):
+    def done_keys(self) -> List[NestedKey]:
+        """Done Keys.
+
+        Returns the keys used to indicate that an episode has ended.
+        The default done keys are "done", "terminated", and "truncated".
These keys can be + used in the environment's output to signal the end of an episode. + + Returns: + A list of strings representing the done keys. + + """ done_keys = getattr(self, "_done_keys", None) if done_keys is None: - self._done_keys = done_keys = ("done", "terminated", "truncated") + self._done_keys = done_keys = ["done", "terminated", "truncated"] return done_keys @done_keys.setter def done_keys(self, value): + if isinstance(value, (str, tuple)): + value = [value] + if value is not None: + value = [unravel_key(val) for val in value] self._done_keys = value @property - def reward_keys(self): + def reward_keys(self) -> List[NestedKey]: + """Reward Keys. + + Returns the keys used to retrieve rewards from the environment's output. + The default reward key is "reward". + + Returns: + A list of strings or tuples representing the reward keys. + + """ reward_keys = getattr(self, "_reward_keys", None) if reward_keys is None: - self._reward_keys = reward_keys = ("reward",) + self._reward_keys = reward_keys = ["reward"] return reward_keys @reward_keys.setter def reward_keys(self, value): + if isinstance(value, (str, tuple)): + value = [value] + if value is not None: + value = [unravel_key(val) for val in value] self._reward_keys = value @property - def action_keys(self): + def action_keys(self) -> List[NestedKey]: + """Action Keys. + + Returns the keys used to retrieve actions from the environment's input. + The default action key is "action". + + Returns: + A list of strings or tuples representing the action keys. + + """ action_keys = getattr(self, "_action_keys", None) if action_keys is None: - self._action_keys = action_keys = ("action",) + self._action_keys = action_keys = ["action"] return action_keys @action_keys.setter def action_keys(self, value): + if isinstance(value, (str, tuple)): + value = [value] + if value is not None: + value = [unravel_key(val) for val in value] self._action_keys = value @property - def observation_keys(self): + def observation_keys(self) -> List[NestedKey]: + """Observation Keys. + + Returns the keys used to retrieve observations from the environment's output. + The default observation key is "observation". + + Returns: + A list of strings or tuples representing the observation keys. + """ observation_keys = getattr(self, "_observation_keys", None) if observation_keys is None: - self._observation_keys = observation_keys = ("observation",) + self._observation_keys = observation_keys = ["observation"] return observation_keys @observation_keys.setter def observation_keys(self, value): + if isinstance(value, (str, tuple)): + value = [value] + if value is not None: + value = [unravel_key(val) for val in value] self._observation_keys = value def get_keys_from_env(self, env: EnvBase): @@ -482,8 +591,21 @@ def get_keys_from_env(self, env: EnvBase): @classmethod def _write_fn_stack(cls, new, old=None): + # This function updates the old values by adding the new ones + # if and only if the new ones are not there. + # If the old value is not provided, we assume there are none and the + # `new` is just prepared. + # This involves unsqueezing the last dim (since we'll be stacking tensors + # and calling unique). + # The update involves calling cat along the last dim + unique + # which will keep only the new values that were unknown to + # the storage. + # We use this method to track all the indices that are associated with + # an observation. Every time a new index is obtained, it is stacked alongside + # the others. 
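+        # A small worked example with hypothetical values: if the storage holds
+        # indices ``old = tensor([[0, 1]])`` for a given observation and a new
+        # index ``new = tensor([1])`` arrives, ``new`` is unsqueezed to ``[[1]]``,
+        # ``cat`` along dim -1 gives ``[[0, 1, 1]]``, and ``unique(dim=-1)``
+        # brings it back to ``[[0, 1]]`` (up to ordering): only previously-unseen
+        # indices are effectively added.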
         if old is None:
-            result = new.apply(lambda x: x.unsqueeze(0), filter_empty=False)
+            # we unsqueeze the values to stack them along dim -1
+            result = new.apply(lambda x: x.unsqueeze(-1), filter_empty=False)
             result.set(
                 "count", torch.ones(result.shape, dtype=torch.int, device=result.device)
             )
@@ -493,8 +615,15 @@ def cat(name, x, y):
             if name == "count":
                 return x
             if y.ndim < x.ndim:
-                y = y.unsqueeze(0)
-            result = torch.cat([x, y], 0).unique(dim=0, sorted=False)
+                y = y.unsqueeze(-1)
+            result = torch.cat([x, y], -1)
+            # unique along a dim breaks on MPS, so we fall back on CPU there
+            if result.device.type == "mps":
+                result = result.cpu()
+                result = result.unique(dim=-1, sorted=False)
+                result = result.to("mps")
+            else:
+                result = result.unique(dim=-1, sorted=False)
             return result
 
         result = old.named_apply(cat, new, default=None)
@@ -543,12 +672,35 @@ def extend(self, rollout):
         # # Set the action in the 'next'
         # dest[1:] = source[:-1].exclude(*self.done_keys)
 
+        # Add ('observation', 'action') -> ('next', 'observation')
         self.data_map[source] = dest
         value = source
         if self.node_map is None:
             self._make_storage_branches(source, dest)
+        # map ('observation',) -> ('indices',)
         self.node_map[source] = TensorDict.lazy_stack(value.unbind(0))
 
+    def add(self, step):
+        source, dest = (
+            step.exclude("next").copy(),
+            step.select("next", *self.action_keys).copy(),
+        )
+
+        if self.data_map is None:
+            self._make_storage(source, dest)
+
+        # We need to set the action somewhere to keep track of what action led to what child
+        # # Set the action in the 'next'
+        # dest[1:] = source[:-1].exclude(*self.done_keys)
+
+        # Add ('observation', 'action') -> ('next', 'observation')
+        self.data_map[source] = dest
+        value = source
+        if self.node_map is None:
+            self._make_storage_branches(source, dest)
+        # map ('observation',) -> ('indices',)
+        self.node_map[source] = value
+
     def get_child(self, root: TensorDictBase) -> TensorDictBase:
         return self.data_map[root]
 
@@ -582,6 +734,14 @@ def _make_local_tree(
                     if not compact:
                         break
                 else:
+                    # If the root is provided and not gathered from the storage, it could be that its
+                    # device doesn't match the data_map storage device.
+ device = getattr(self.data_map.storage, "device", None) + if root.device != device: + if device is not None: + root = root.to(self.data_map.storage.device) + else: + root.clear_device_() index = None break rollout = None diff --git a/torchrl/data/map/utils.py b/torchrl/data/map/utils.py index 570214f1cb2..d9588d79905 100644 --- a/torchrl/data/map/utils.py +++ b/torchrl/data/map/utils.py @@ -17,13 +17,13 @@ def _plot_plotly_tree( if make_labels is None: - def make_labels(tree): + def make_labels(tree, path, *args, **kwargs): return str((tree.node_id, tree.hash)) nr_vertices = tree.num_vertices() - vertices = tree.vertices() + vertices = tree.vertices(key_type="path") - v_label = [make_labels(subtree) for subtree in vertices.values()] + v_label = [make_labels(subtree, path) for path, subtree in vertices.items()] G = Graph(nr_vertices, tree.edges()) layout = G.layout_sugiyama(range(nr_vertices)) diff --git a/torchrl/data/replay_buffers/storages.py b/torchrl/data/replay_buffers/storages.py index 665cae254f5..ae0d97b7bab 100644 --- a/torchrl/data/replay_buffers/storages.py +++ b/torchrl/data/replay_buffers/storages.py @@ -246,8 +246,8 @@ def set( set_cursor: bool = True, ): if not isinstance(cursor, INT_CLASSES): - if (isinstance(cursor, torch.Tensor) and cursor.numel() <= 1) or ( - isinstance(cursor, np.ndarray) and cursor.size <= 1 + if (isinstance(cursor, torch.Tensor) and cursor.ndim == 0) or ( + isinstance(cursor, np.ndarray) and cursor.ndim == 0 ): self.set(int(cursor), data, set_cursor=set_cursor) return diff --git a/torchrl/data/tensor_specs.py b/torchrl/data/tensor_specs.py index ddf6ed41c99..dad0aaf69a6 100644 --- a/torchrl/data/tensor_specs.py +++ b/torchrl/data/tensor_specs.py @@ -41,7 +41,7 @@ unravel_key, ) from tensordict.base import NO_DEFAULT -from tensordict.utils import _getitem_batch_size, NestedKey +from tensordict.utils import _getitem_batch_size, is_non_tensor, NestedKey from torchrl._utils import _make_ordinal_device, get_binary_env_var, implement_for DEVICE_TYPING = Union[torch.device, str, int] @@ -2466,10 +2466,10 @@ def one(self, shape=None): data=None, batch_size=(*shape, *self._safe_shape), device=self.device ) - def is_in(self, val: torch.Tensor) -> bool: + def is_in(self, val: Any) -> bool: shape = torch.broadcast_shapes(self._safe_shape, val.shape) return ( - isinstance(val, NonTensorData) + is_non_tensor(val) and val.shape == shape # We relax constrains on device as they're hard to enforce for non-tensor # tensordicts and pointless @@ -4373,7 +4373,7 @@ def set(self, name, spec): shape = spec.shape if shape[: self.ndim] != self.shape: if ( - isinstance(spec, Composite) + isinstance(spec, (Composite, NonTensor)) and spec.ndim < self.ndim and self.shape[: spec.ndim] == spec.shape ): @@ -4382,7 +4382,7 @@ def set(self, name, spec): spec.shape = self.shape else: raise ValueError( - "The shape of the spec and the Composite mismatch: the first " + f"The shape of the spec {type(spec).__name__} and the Composite {type(self).__name__} mismatch: the first " f"{self.ndim} dimensions should match but got spec.shape={spec.shape} and " f"Composite.shape={self.shape}." 
) diff --git a/torchrl/envs/__init__.py b/torchrl/envs/__init__.py index 36e4ec1a908..f3dec221ce0 100644 --- a/torchrl/envs/__init__.py +++ b/torchrl/envs/__init__.py @@ -5,7 +5,7 @@ from .batched_envs import ParallelEnv, SerialEnv from .common import EnvBase, EnvMetaData, make_tensordict -from .custom import PendulumEnv, TicTacToeEnv +from .custom import LLMHashingEnv, PendulumEnv, TicTacToeEnv from .env_creator import env_creator, EnvCreator, get_env_metadata from .gym_like import default_info_dict_reader, GymLikeEnv from .libs import ( diff --git a/torchrl/envs/common.py b/torchrl/envs/common.py index bafe88b639a..13d02af2b36 100644 --- a/torchrl/envs/common.py +++ b/torchrl/envs/common.py @@ -14,8 +14,14 @@ import numpy as np import torch import torch.nn as nn -from tensordict import LazyStackedTensorDict, TensorDictBase, unravel_key -from tensordict.utils import NestedKey +from tensordict import ( + is_tensor_collection, + LazyStackedTensorDict, + TensorDictBase, + unravel_key, +) +from tensordict.base import _is_leaf_nontensor +from tensordict.utils import is_non_tensor, NestedKey from torchrl._utils import ( _ends_with, _make_ordinal_device, @@ -25,7 +31,13 @@ seed_generator, ) -from torchrl.data.tensor_specs import Categorical, Composite, TensorSpec, Unbounded +from torchrl.data.tensor_specs import ( + Categorical, + Composite, + NonTensor, + TensorSpec, + Unbounded, +) from torchrl.data.utils import DEVICE_TYPING from torchrl.envs.utils import ( _make_compatible_policy, @@ -430,7 +442,6 @@ def auto_specs_( done_key: NestedKey | List[NestedKey] | None = None, observation_key: NestedKey | List[NestedKey] = "observation", reward_key: NestedKey | List[NestedKey] = "reward", - batch_size: torch.Size | None = None, ): """Automatically sets the specifications (specs) of the environment based on a random rollout using a given policy. @@ -484,6 +495,7 @@ def auto_specs_( tensordict2, named=True, nested_keys=True, + is_leaf=_is_leaf_nontensor, ) input_spec = Composite(input_spec_stack, batch_size=batch_size) if not self.batch_locked and batch_size != self.batch_size: @@ -501,6 +513,7 @@ def auto_specs_( nexts_1, named=True, nested_keys=True, + is_leaf=_is_leaf_nontensor, ) output_spec = Composite(output_spec_stack, batch_size=batch_size) @@ -523,7 +536,8 @@ def auto_specs_( full_observation_spec = output_spec.separates(*observation_key, default=None) if not output_spec.is_empty(recurse=True): raise RuntimeError( - f"Keys {list(output_spec.keys(True, True))} are unaccounted for." + f"Keys {list(output_spec.keys(True, True))} are unaccounted for. " + f"Make sure you have passed all the leaf names to the auto_specs_ method." ) if full_action_spec is not None: @@ -2999,6 +3013,52 @@ def add_truncated_keys(self) -> EnvBase: self.__dict__["_done_keys"] = None return self + def step_mdp(self, next_tensordict: TensorDictBase) -> TensorDictBase: + """Advances the environment state by one step using the provided `next_tensordict`. + + This method updates the environment's state by transitioning from the current + state to the next, as defined by the `next_tensordict`. The resulting tensordict + includes updated observations and any other relevant state information, with + keys managed according to the environment's specifications. + + Internally, this method utilizes a precomputed :class:`~torchrl.envs.utils._StepMDP` instance to efficiently + handle the transition of state, observation, action, reward, and done keys. 
The
+        :class:`~torchrl.envs.utils._StepMDP` class optimizes the process by precomputing the keys to include and
+        exclude, reducing runtime overhead during repeated calls. The :class:`~torchrl.envs.utils._StepMDP` instance
+        is created with `exclude_action=False`, meaning that action keys are retained in
+        the root tensordict.
+
+        Args:
+            next_tensordict (TensorDictBase): A tensordict containing the state of the
+                environment at the next time step. This tensordict should include keys
+                for observations, actions, rewards, and done flags, as defined by the
+                environment's specifications.
+
+        Returns:
+            TensorDictBase: A new tensordict representing the environment state after
+                advancing by one step.
+
+        .. note:: The method ensures that the environment's key specifications are validated
+            against the provided `next_tensordict`, issuing warnings if discrepancies
+            are found.
+
+        .. note:: This method is designed to work efficiently with environments that have
+            consistent key specifications, leveraging the `_StepMDP` class to minimize
+            overhead.
+
+        Example:
+            >>> from torchrl.envs import GymEnv
+            >>> env = GymEnv("Pendulum-v1")
+            >>> data = env.reset()
+            >>> for i in range(10):
+            ...     # compute action
+            ...     env.rand_action(data)
+            ...     # Perform action
+            ...     next_data = env.step(data)
+            ...     data = env.step_mdp(next_data)
+        """
+        return self._step_mdp(next_tensordict)
+
     @property
     def _step_mdp(self):
         step_func = self.__dict__.get("_step_mdp_value")
@@ -3572,6 +3632,12 @@ def _has_dynamic_specs(spec: Composite):
 
 
 def _tensor_to_spec(name, leaf, leaf_compare=None, *, stack):
+    if not (isinstance(leaf, torch.Tensor) or is_tensor_collection(leaf)):
+        stack[name] = NonTensor(shape=())
+        return
+    elif is_non_tensor(leaf):
+        stack[name] = NonTensor(shape=leaf.shape)
+        return
     shape = leaf.shape
     if leaf_compare is not None:
         shape_compare = leaf_compare.shape
diff --git a/torchrl/envs/custom/__init__.py b/torchrl/envs/custom/__init__.py
index 8649d3d3e97..375a0e23a57 100644
--- a/torchrl/envs/custom/__init__.py
+++ b/torchrl/envs/custom/__init__.py
@@ -3,5 +3,6 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
 
+from .llm import LLMHashingEnv
 from .pendulum import PendulumEnv
 from .tictactoeenv import TicTacToeEnv
diff --git a/torchrl/envs/custom/llm.py b/torchrl/envs/custom/llm.py
new file mode 100644
index 00000000000..0413671f32c
--- /dev/null
+++ b/torchrl/envs/custom/llm.py
@@ -0,0 +1,156 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+from typing import Callable, List, Union
+
+import torch
+from tensordict import NestedKey, TensorDictBase
+from tensordict.tensorclass import NonTensorData, NonTensorStack
+
+from torchrl.data import (
+    Categorical as CategoricalSpec,
+    Composite,
+    NonTensor,
+    SipHash,
+    Unbounded,
+)
+from torchrl.envs import EnvBase
+from torchrl.envs.utils import _StepMDP
+
+
+class LLMHashingEnv(EnvBase):
+    """A text generation environment that uses a hashing module to identify unique observations.
+
+    Args:
+        vocab_size (int): The size of the vocabulary.
+        hashing_module (Callable[[torch.Tensor], torch.Tensor], optional):
+            A hashing function that takes a tensor as input and returns a hashed tensor.
+            Defaults to :class:`~torchrl.data.SipHash` if not provided.
+        observation_key (NestedKey, optional): The key for the observation in the TensorDict.
+ Defaults to "observation". + text_output (bool, optional): Whether to include the text output in the observation. + Defaults to True. + tokenizer (transformers.Tokenizer | None, optional): + A tokenizer function that converts text to tensors. + Only used when `text_output` is `True`. + Must implement the following methods: `decode` and `batch_decode`. + Defaults to ``None``. + text_key (NestedKey | None, optional): The key for the text output in the TensorDict. + Defaults to "text". + + """ + + def __init__( + self, + vocab_size: int, + hashing_module: Callable[[torch.Tensor], torch.Tensor] = None, + observation_key: NestedKey = "observation", + text_output: bool = True, + tokenizer: Callable[[Union[str, List[str]]], torch.Tensor] | None = None, + text_key: NestedKey | None = "text", + ): + super().__init__() + self._batch_locked = False + if hashing_module is None: + hashing_module = SipHash() + + self._hashing_module = hashing_module + self._tokenizer = tokenizer + self.observation_key = observation_key + observation_spec = { + observation_key: CategoricalSpec(n=vocab_size, shape=(-1,)), + "hash": Unbounded(shape=(1,), dtype=torch.int64), + } + self.text_output = text_output + if not text_output: + text_key = None + elif text_key is None: + text_key = "text" + if text_key is not None: + observation_spec[text_key] = NonTensor(shape=()) + self.text_key = text_key + self.observation_spec = Composite(observation_spec) + self.action_spec = Composite(action=CategoricalSpec(vocab_size, shape=(1,))) + _StepMDP(self) + + def _reset(self, tensordict: TensorDictBase): + """Initializes the environment with a given observation. + + Args: + tensordict (TensorDictBase): A TensorDict containing the initial observation. + + Returns: + A TensorDict containing the initial observation, its hash, and other relevant information. + + """ + out = tensordict.empty() + obs = tensordict.get(self.observation_key) + if self.text_output: + if obs.ndim > 1: + text = self._tokenizer.batch_decode(obs) + text = NonTensorStack.from_list(text) + else: + text = self._tokenizer.decode(obs) + text = NonTensorData(text) + out.set(self.text_key, text) + + if obs.ndim > 1: + out.set("hash", self._hashing_module(obs).unsqueeze(-1)) + else: + out.set("hash", self._hashing_module(obs.unsqueeze(0)).transpose(0, -1)) + + if not self.full_done_spec.is_empty(): + out.update(self.full_done_spec.zero(tensordict.shape)) + else: + out.set("done", torch.zeros((*tensordict.batch_size, 1), dtype=torch.bool)) + out.set( + "terminated", torch.zeros((*tensordict.batch_size, 1), dtype=torch.bool) + ) + return out + + def _step(self, tensordict): + """Takes an action (i.e., the next token to generate) and returns the next observation and reward. + + Args: + tensordict: A TensorDict containing the current observation and action. + + Returns: + A TensorDict containing the next observation, its hash, and other relevant information. 
+ """ + out = tensordict.empty() + action = tensordict.get("action") + obs = torch.cat([tensordict.get(self.observation_key), action], -1) + kwargs = {self.observation_key: obs} + + catval = torch.cat([tensordict.get("hash"), action], -1) + if obs.ndim > 1: + new_hash = self._hashing_module(catval).unsqueeze(-1) + else: + new_hash = self._hashing_module(catval.unsqueeze(0)).transpose(0, -1) + + if self.text_output: + if obs.ndim > 1: + text = self._tokenizer.batch_decode(obs) + text = NonTensorStack.from_list(text) + else: + text = self._tokenizer.decode(obs) + text = NonTensorData(text) + kwargs[self.text_key] = text + kwargs.update( + { + "hash": new_hash, + "done": torch.zeros((*tensordict.batch_size, 1), dtype=torch.bool), + "terminated": torch.zeros( + (*tensordict.batch_size, 1), dtype=torch.bool + ), + } + ) + return out.update(kwargs) + + def _set_seed(self, *args): + """Sets the seed for the environment's randomness. + + .. note:: This environment has no randomness, so this method does nothing. + """ + pass diff --git a/torchrl/envs/utils.py b/torchrl/envs/utils.py index 209349878ec..9cba14c9690 100644 --- a/torchrl/envs/utils.py +++ b/torchrl/envs/utils.py @@ -14,7 +14,7 @@ import re import warnings from enum import Enum -from typing import Any, Dict, List, Union +from typing import Any, Dict, List import torch @@ -76,7 +76,7 @@ def __get__(self, cls, owner): class _StepMDP: - """Stateful version of step_mdp. + """Stateful version of :func:`~torchrl.envs.step_mdp`. Precomputes the list of keys to include and exclude during a call to step_mdp to reduce runtime. @@ -339,48 +339,47 @@ def step_mdp( exclude_reward: bool = True, exclude_done: bool = False, exclude_action: bool = True, - reward_keys: Union[NestedKey, List[NestedKey]] = "reward", - done_keys: Union[NestedKey, List[NestedKey]] = "done", - action_keys: Union[NestedKey, List[NestedKey]] = "action", + reward_keys: NestedKey | List[NestedKey] = "reward", + done_keys: NestedKey | List[NestedKey] = "done", + action_keys: NestedKey | List[NestedKey] = "action", ) -> TensorDictBase: """Creates a new tensordict that reflects a step in time of the input tensordict. Given a tensordict retrieved after a step, returns the :obj:`"next"` indexed-tensordict. - The arguments allow for a precise control over what should be kept and what + The arguments allow for precise control over what should be kept and what should be copied from the ``"next"`` entry. The default behavior is: - move the observation entries, reward and done states to the root, exclude - the current action and keep all extra keys (non-action, non-done, non-reward). + move the observation entries, reward, and done states to the root, exclude + the current action, and keep all extra keys (non-action, non-done, non-reward). Args: - tensordict (TensorDictBase): tensordict with keys to be renamed - next_tensordict (TensorDictBase, optional): destination tensordict - keep_other (bool, optional): if ``True``, all keys that do not start with :obj:`'next_'` will be kept. + tensordict (TensorDictBase): The tensordict with keys to be renamed. + next_tensordict (TensorDictBase, optional): The destination tensordict. If `None`, a new tensordict is created. + keep_other (bool, optional): If ``True``, all keys that do not start with :obj:`'next_'` will be kept. Default is ``True``. 
- exclude_reward (bool, optional): if ``True``, the :obj:`"reward"` key will be discarded + exclude_reward (bool, optional): If ``True``, the :obj:`"reward"` key will be discarded from the resulting tensordict. If ``False``, it will be copied (and replaced) - from the ``"next"`` entry (if present). - Default is ``True``. - exclude_done (bool, optional): if ``True``, the :obj:`"done"` key will be discarded + from the ``"next"`` entry (if present). Default is ``True``. + exclude_done (bool, optional): If ``True``, the :obj:`"done"` key will be discarded from the resulting tensordict. If ``False``, it will be copied (and replaced) - from the ``"next"`` entry (if present). - Default is ``False``. - exclude_action (bool, optional): if ``True``, the :obj:`"action"` key will + from the ``"next"`` entry (if present). Default is ``False``. + exclude_action (bool, optional): If ``True``, the :obj:`"action"` key will be discarded from the resulting tensordict. If ``False``, it will be kept in the root tensordict (since it should not be present in - the ``"next"`` entry). - Default is ``True``. - reward_keys (NestedKey or list of NestedKey, optional): the keys where the reward is written. Defaults + the ``"next"`` entry). Default is ``True``. + reward_keys (NestedKey or list of NestedKey, optional): The keys where the reward is written. Defaults to "reward". - done_keys (NestedKey or list of NestedKey, optional): the keys where the done is written. Defaults + done_keys (NestedKey or list of NestedKey, optional): The keys where the done is written. Defaults to "done". - action_keys (NestedKey or list of NestedKey, optional): the keys where the action is written. Defaults + action_keys (NestedKey or list of NestedKey, optional): The keys where the action is written. Defaults to "action". Returns: - A new tensordict (or next_tensordict) containing the tensors of the t+1 step. + TensorDictBase: A new tensordict (or `next_tensordict` if provided) containing the tensors of the t+1 step. + + .. seealso:: :meth:`EnvBase.step_mdp` is the class-based version of this free function. It will attempt to cache the + key values to reduce the overhead of making a step in the MDP. Examples: - This funtion allows for this kind of loop to be used: >>> from tensordict import TensorDict >>> import torch >>> td = TensorDict({ @@ -783,7 +782,9 @@ def check_env_specs( from torchrl.envs.common import _has_dynamic_specs if _has_dynamic_specs(env.specs): - for real, fake in zip(real_tensordict.unbind(-1), fake_tensordict.unbind(-1)): + for real, fake in zip( + real_tensordict_select.unbind(-1), fake_tensordict_select.unbind(-1) + ): fake = fake.apply(lambda x, y: x.expand_as(y), real) if (torch.zeros_like(real) != torch.zeros_like(fake)).any(): raise AssertionError(zeroing_err_msg) @@ -1367,6 +1368,8 @@ def _update_during_reset( reset_keys: List[NestedKey], ): """Updates the input tensordict with the reset data, based on the reset keys.""" + if not reset_keys: + return tensordict.update(tensordict_reset) roots = set() for reset_key in reset_keys: # get the node of the reset key diff --git a/tutorials/sphinx-tutorials/beam_search_with_gpt.py b/tutorials/sphinx-tutorials/beam_search_with_gpt.py new file mode 100644 index 00000000000..a3214e89b4e --- /dev/null +++ b/tutorials/sphinx-tutorials/beam_search_with_gpt.py @@ -0,0 +1,411 @@ +""" +Beam Search with TorchRL +======================== + +Key learning +------------ + +In this tutorial, you will learn how to use TorchRL to implement beam search for efficient text generation. 
+You will understand how to define a policy, build an environment, and run the policy using a beam search algorithm.
+
+Introduction
+------------
+Text generation is a fundamental task in natural language processing (NLP) that has numerous applications in chatbots,
+language translation, and content creation. One of the challenges in text generation is efficiently exploring the vast
+space of possible sequences to find the most coherent and relevant output. Beam search is a popular heuristic search
+algorithm used to address this challenge; the next section describes it in more detail.
+
+
+Introduction to Beam Search
+---------------------------
+
+Beam search is a heuristic search algorithm used in many natural language processing tasks, including machine
+translation, summarization, and text generation. It works by maintaining a set of candidate solutions (or "beams") at
+each step, and selecting the top-scoring candidates to move forward to the next step.
+
+"""
+import argparse
+
+import torch
+import tqdm
+from tensordict import NonTensorStack, TensorDict
+from tensordict.nn import (
+    ProbabilisticTensorDictModule as Prob,
+    TensorDictModule as Mod,
+    TensorDictSequential as Seq,
+)
+from torch.distributions import Categorical
+
+from torchrl._utils import _make_ordinal_device
+from torchrl.data import MCTSForest
+
+from torchrl.envs import LLMHashingEnv
+from transformers import GPT2Config, GPT2LMHeadModel, GPT2Tokenizer, pipeline
+
+try:
+    is_sphinx = __sphinx_build__
+except NameError:
+    is_sphinx = False
+
+parser = argparse.ArgumentParser()
+parser.add_argument(
+    "--pretrained",
+    # NB: argparse's ``type=bool`` would treat any non-empty string (even "False")
+    # as True, so we parse the flag explicitly
+    type=lambda x: str(x).lower() in ("1", "true", "yes"),
+    default=not is_sphinx,
+    help="Set to True to load pre-trained weights, False for random weights.",
+)
+parser.add_argument(
+    "--model",
+    choices=["llama3.1", "gpt2"],
+    default="gpt2",
+    help="Choose the model to use: 'llama3.1' or 'gpt2'.",
+)
+parser.add_argument(
+    "--beta", type=int, default=3, help="Set the number of beams (beta) kept at each step."
+)
+parser.add_argument(
+    "--pool",
+    type=int,
+    default=1000,
+    help="Set the number of candidate tokens sampled at each step (the pool size).",
+)
+parser.add_argument(
+    "--nsteps", type=int, default=10, help="Set the number of generation steps."
+)
+parser.add_argument(
+    "--device",
+    type=str,
+    default=None,
+    help="Specify the device to use (e.g., 'cpu', 'cuda').",
+)
+parser.add_argument(
+    "--device_map",
+    type=str,
+    default="auto",
+    help="Specify the device map for model parallelism (e.g., 'auto').",
+)
+
+args = parser.parse_args(
+    [
+        # When executing this in a notebook, change the parameters here, e.g.,
+        # "--device", "cuda:0"
+    ]
+)
+
+################################################
+# Build the model
+# ---------------
+# In this example, we use a pre-trained GPT-2 model as our language model.
+# We define an ``LLMWrapper`` class below to wrap the GPT-2 model and return its output as a TensorDict.
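+#
+# Throughout the tutorial, we compose functions with ``Mod`` (:class:`~tensordict.nn.TensorDictModule`)
+# and ``Seq`` (:class:`~tensordict.nn.TensorDictSequential`): a module reads its ``in_keys``
+# from the input TensorDict and writes its results to its ``out_keys``. As a minimal,
+# self-contained sketch of this pattern (the ``"doubled"`` entry is purely illustrative
+# and is not reused later):
+
+_double = Mod(lambda x: x * 2, in_keys=["observation"], out_keys=["doubled"])
+_td_demo = TensorDict(observation=torch.arange(4), batch_size=[])
+assert (_double(_td_demo)["doubled"] == torch.tensor([0, 2, 4, 6])).all()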
+
+if args.model == "gpt2":
+    tokenizer = GPT2Tokenizer.from_pretrained("openai-community/gpt2")
+    if args.pretrained:
+        # load the pre-trained weights (loading only the config would give random weights)
+        llm = GPT2LMHeadModel.from_pretrained("openai-community/gpt2")
+    else:
+        cfg = GPT2Config()
+        llm = GPT2LMHeadModel(cfg)
+    llm = llm.eval().requires_grad_(False)
+
+    if torch.cuda.is_available():
+        device = "cuda:0"
+    else:
+        device = "cpu"
+
+elif args.model == "llama3.1":
+    if not args.pretrained:
+        raise ValueError("llama3.1 can only be used with --pretrained=True")
+
+    model_id = "meta-llama/Llama-3.1-8B"
+
+    if args.device:
+        args.device_map = None
+    pipeline = pipeline(
+        "text-generation",
+        model=model_id,
+        model_kwargs={"torch_dtype": torch.bfloat16},
+        device_map=args.device_map,
+        device=args.device,
+    )
+
+    tokenizer = pipeline.tokenizer
+    llm = pipeline.model.eval().requires_grad_(False)
+    if args.device:
+        device = _make_ordinal_device(args.device)
+    elif torch.cuda.is_available():
+        device = "cuda:0"
+    elif torch.mps.is_available():
+        torch.mps.empty_cache()
+        device = "mps:0"
+    else:
+        device = "cpu"
+
+torch.set_default_device(device)
+
+text_to_tensor = Seq(
+    Mod(tokenizer, in_keys=["query"], out_keys=["out"]),
+    # A renaming layer
+    Mod(lambda x: x, in_keys=[("out", "input_ids")], out_keys=["observation"]),
+).select_out_keys("observation")
+td = TensorDict(
+    query=NonTensorStack.from_list(["hello world! Give me a high five"] * 4),
+    batch_size=[4],
+)
+print(text_to_tensor(td))
+
+
+################################################
+# LLM Environment with Hashing
+# ----------------------------
+#
+# This environment represents a dataset of text sequences as a Markov Decision Process (MDP), where each observation is
+# reduced to a unique integer using a hashing module.
+#
+# By hashing observations, we can efficiently store and retrieve them: instead of identifying nodes with their
+# associated (observation, action) pair, we use a (hash, action) pair. This approach has multiple advantages:
+#
+# - Observations have a variable shape, making it hard to preallocate storage or store them contiguously and efficiently
+#   in memory.
+# - Using observations directly incurs extra memory cost as duplicated data will be stored (since successive steps in a
+#   branch share several tokens). Successive nodes only differ by an extra action (i.e., an extra token), and
+#   lateral (sibling) nodes differ only by their action.
+#   The only information we need to store is the action associated with the node itself. To reconstruct the sequence of
+#   tokens up to a certain node, we concatenate the actions along the path to that node - an operation for which
+#   TorchRL has all the necessary tooling.
+#
+# The environment has two main methods: `_reset`, which initializes the environment with a given observation, and
+# `_step`, which takes an action (i.e., the next token in the sequence) and returns the updated observation.
+#
+# .. figure:: /_static/img/rollout-llm.png
+#   :alt: Data collection loop with our LLM environment.
+#
+
+env = LLMHashingEnv(vocab_size=tokenizer.vocab_size, tokenizer=tokenizer)
+
+################################################
+# Define the policy
+# -----------------
+#
+# In this section, we define a "policy" that takes an observation as input and outputs an action. Note that in our
+# context, the term "policy" is used to fit into control frameworks, but at its core, it's simply a language model
+# with some additional pre and post-processing steps.
+#
+# Policy Architecture
+# ~~~~~~~~~~~~~~~~~~~
+#
+# Our policy consists of a sequence of modules:
+#
+# 1. Select unique observations (or nodes) in the input data.
+# 2. LLMWrapper: This module wraps the GPT-2 model and provides a convenient interface for generating output.
+# 3. Select last logit: This module selects the last logit from the output of the LLMWrapper.
+# 4. Probabilistic sampling: This module samples from a categorical distribution to select the next token.
+# 5. Reshape: This module reshapes the output to a 1D tensor.
+# 6. Top-k selection: This module selects the top-k tokens with the highest probabilities.
+#
+# Selecting unique observations
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+#
+# It might be the case that the policy receives multiple identical states in the batch.
+# To make the computation more efficient, we want to select only the unique values of these observations.
+# Doing this also reduces the chances that we will be generating redundant trajectories.
+# Because we have hashes that uniquely define the trajectory up to a given step, it's easier to use these hash values
+# to pinpoint the unique nodes rather than using the observations directly.
+#
+# Notice that indexing the relevant nodes is made easy thanks to tensordict's API!
+#
+
+
+def select_unique_obs(td):
+    # Get the obs (the hash)
+    hashes = td["hash"]
+    hashes = hashes.squeeze()
+    assert hashes.ndim == 1
+    # Keep one representative batch index per unique hash: `unique` maps every
+    # element to the index of its unique value (`return_inverse=True`); writing
+    # the batch indices through that mapping keeps one (arbitrary) index per
+    # unique hash.
+    _, inverse = torch.unique(hashes, dim=0, return_inverse=True)
+    representatives = torch.empty(
+        int(inverse.max()) + 1, dtype=torch.long, device=inverse.device
+    )
+    representatives[inverse] = torch.arange(inverse.numel(), device=inverse.device)
+    return td[representatives]
+
+
+################################################
+# LLMWrapper
+# ~~~~~~~~~~
+#
+# The LLMWrapper module is a simple wrapper around the LLM. It takes the observation (i.e., the current text)
+# as input and outputs the result of the LLM presented as a TensorDict instance.
+#
+
+
+class LLMWrapper(torch.nn.Module):
+    def __init__(self, gpt):
+        super().__init__()
+        self.gpt = gpt
+
+    def forward(self, x: torch.Tensor) -> TensorDict:
+        result = TensorDict.from_dataclass(self.gpt(x, return_dict=True), device=device)
+        return result
+
+
+llm_module = Mod(LLMWrapper(llm), in_keys=["observation"], out_keys=["data"])
+
+################################################
+# Select last logits
+# ~~~~~~~~~~~~~~~~~~
+#
+# To select the best actions, we are only going to look at the last logit of the sequence. Another option could
+# be to aggregate the logits together using a :meth:`~torch.Tensor.sum` operator.
+
+select_last = Mod(
+    lambda x: x[:, -1:], in_keys=[("data", "logits")], out_keys=["logits"]
+)
+
+################################################
+# Probabilistic Sampling
+# ~~~~~~~~~~~~~~~~~~~~~~
+#
+# The probabilistic sampling module samples from a categorical distribution to select the next token.
+# We use a custom ``CategoricalWithoutReplacement`` class to ensure that the same token is not selected twice.
+#
+# We then use a :class:`~tensordict.nn.ProbabilisticTensorDictModule` to build the distribution from the logits
+# and sample from it on-the-fly. Through the ``log_prob_key`` keyword argument, we indicate that we want to register
+# the value of the log-probability in the tensordict (which we will need for the beam search algorithm).
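+#
+# As a standalone illustration (with hypothetical logits, not part of the
+# pipeline), sampling without replacement boils down to ``torch.multinomial``
+# with ``replacement=False``: every draw returns distinct indices.
+
+_logits_demo = torch.tensor([[2.0, 0.5, 1.0, 0.1]])  # [batch=1, vocab=4]
+_samples_demo = torch.multinomial(_logits_demo.softmax(-1), 3, replacement=False)
+assert _samples_demo.unique().numel() == 3  # three distinct token indices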
+
+
+class CategoricalWithoutReplacement(Categorical):
+    def sample(self, sample_shape=()) -> torch.Tensor:
+        n = sample_shape.numel()
+        probs = self.probs
+        probs_shape = probs.shape
+        if len(probs_shape) > 2:
+            probs = probs.flatten(0, -2)
+        samples = torch.multinomial(probs, n, replacement=False)
+        return samples.view((*sample_shape, *probs_shape[:-1]))
+
+
+prob_module = Prob(
+    in_keys=["logits"],
+    out_keys=["action"],
+    default_interaction_type="random",
+    distribution_class=CategoricalWithoutReplacement,
+    return_log_prob=True,
+    log_prob_key="logits_select",
+    num_samples=args.pool,
+)
+
+
+################################################
+# Top-k Selection
+# ~~~~~~~~~~~~~~~
+#
+# The top-k selection module selects the top-k tokens with the highest probabilities.
+
+
+def select_top_k(td: TensorDict, top_k=args.beta) -> TensorDict:
+    logits = td["logits_select"]
+    topk = logits.topk(top_k, dim=0)
+    topk_indices = topk.indices.squeeze(-1)
+    return td[topk_indices].set("topk_indices", topk_indices)
+
+
+################################################
+# Putting modules together using :class:`~tensordict.nn.TensorDictSequential`
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+#
+# Orchestrating :class:`~tensordict.nn.TensorDictModule` instances is easy, as every module receives and returns a
+# TensorDict instance. Whether one, two, more (or even no!) tensors need to be read or written, we can simply
+# chain the modules in a sequence and let :class:`~tensordict.nn.TensorDictSequential` do the rest.
+#
+
+policy = Seq(
+    # Only get the unique obs
+    select_unique_obs,
+    # Call to the LLM
+    llm_module,
+    # Select last logit
+    select_last,
+    # Sample
+    prob_module,
+    # Reshape to -1
+    lambda td: td.reshape(-1),
+    # Top-k
+    select_top_k,
+)
+
+################################################
+# Check specs
+# -----------
+#
+# Verify that the environment's observation and action specs match the input data.

+x = tokenizer(["Check out TorchRL!"])["input_ids"]
+td = TensorDict(observation=x, batch_size=[1])
+td = env.reset(td)
+env.check_env_specs(tensordict=td, return_contiguous=False)
+
+################################################
+# Create a forest to store the data
+# ---------------------------------
+#
+# A :class:`~torchrl.data.MCTSForest` is a collection of trees. You can think of it as a dynamic dataset
+# where we will register every new entry (node, as defined by an observation) using the previous `(observation, action)`
+# pair. Later, we will be able to query the Forest for a tree given a starting node through the
+# :meth:`~torchrl.data.MCTSForest.get_tree` method.
+#
+
+forest = MCTSForest(observation_keys=["hash"], action_keys=["action", "logits_select"])
+
+################################################
+# Run policy
+# ----------
+#
+# Here comes the fun part: we will execute the policy and generate new token sequences from it.
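+#
+# The loop below follows the canonical TorchRL collection pattern: the policy
+# adds ``"action"`` (and ``"logits_select"``) to the tensordict, ``env.step``
+# writes the ``("next", ...)`` entries, and ``env.step_mdp`` promotes ``"next"``
+# to the root for the following iteration. Schematically::
+#
+#     td = policy(td)             # adds "action" and "logits_select"
+#     next_td = env.step(td)      # fills ("next", "observation"), ("next", "hash"), ...
+#     td = env.step_mdp(next_td)  # "next" becomes the new root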
+# + +with torch.no_grad(): + # Total number of candidates + pool = args.pool + # Number of selected beams + beta = args.beta + x = tokenizer(["Check out TorchRL!"])["input_ids"] + reset_td = env.reset( + TensorDict(observation=x, batch_size=[1]).repeat_interleave(args.beta) + ) + tds = [] + # beam search + td = reset_td + reset_td = reset_td[0].clone() + + pbar = tqdm.tqdm(range(args.nsteps)) + for _ in pbar: + td = policy(td) + next_td = env.step(td) + + tds.append(next_td) + next_td_filtered = next_td.exclude( + "observation", "text", ("next", "observation"), ("next", "text") + ) + forest.extend(next_td_filtered) + pbar.set_description(f"Forest length: {len(forest)}") + + print("action", next_td["action"]) + td = env.step_mdp(next_td) + print("hash", td["hash"]) + + tds = TensorDict.lazy_stack(tds, -1) + for i in range(tds.shape[0]): + print(tds[i, -1]["next", "text"]) + + tree = forest.get_tree(reset_td) + valid_paths = list(tree.valid_paths()) + print("valid paths", valid_paths) + + for path in valid_paths: + rollout = tree.rollout_from_path(path) + print("Check out TorchRL!", tokenizer.decode(rollout["action"].squeeze(-1))) + print(rollout["logits_select"].sum()) + + def make_labels(local_tree, path): + if path: + r = tree.rollout_from_path(path) + actions = r["action"] + return "Check out TorchRL! " + tokenizer.decode(actions.squeeze(-1)) + return "Check out TorchRL!" + + tree.plot(make_labels=make_labels)
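+
+    # As a final sketch, we can rank the beams by the cumulative log-probability
+    # of their sampled actions and print the best one:
+    best_path = max(
+        valid_paths,
+        key=lambda path: tree.rollout_from_path(path)["logits_select"].sum().item(),
+    )
+    best_rollout = tree.rollout_from_path(best_path)
+    print("Check out TorchRL!", tokenizer.decode(best_rollout["action"].squeeze(-1)))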