Update writing_batching_rules.md (pytorch#108007)
Was reading through the batching rules info which is very cool and just saw a couple of typos 😊.

Thanks

Pull Request resolved: pytorch#108007
Approved by: https://github.com/msaroufim
james-a-watson authored and pytorchmergebot committed Aug 26, 2023
1 parent a18ee0c commit 808e088
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions functorch/writing_batching_rules.md
@@ -7,11 +7,11 @@ Vmap is a function transform (pioneered by Jax) that allows one to batch functio

This guide will gloss over all the cool things you can do with this (there are many!), so let's focus on how we actually implement this.

- One misconception is that this is some magic compiler voodoo, or that it is inherent some function transform. It is not - and there's another framing of it that might make it more clear.
+ One misconception is that this is some magic compiler voodoo, or that it is inherently some function transform. It is not - and there's another framing of it that might make it more clear.

Instead of providing `vmap`, imagine that we provide a `BatchedTensor` instead. This `BatchedTensor` wraps a `Tensor[B, N, M]`. *But*, to all the users of this tensor, it looks like a `Tensor[N, M]` (that is, without the `B` dimension). Then, when operations are done on this tensor, it transforms that operation to broadcast over the additional `B` dimension as well.

- For example, let's say that we wanted to sum a `BatchedTensor` with shape `[5]` - that is, `torch.sum(x)`. This would give us back a `BatchedTensor` with shape `[]` (i.e. a scalar tensor). **But**, in reality, we this is actually a `Tensor` with shape `[B]`. Instead of running `torch.sum(x: [5])`, we ran `torch.sum(x: [B, 5], dim=1)`. In other words, we transformed the sum operation so that instead of summing the whole tensor, it summed all the dimensions *except* the batch dimension.
+ For example, let's say that we wanted to sum a `BatchedTensor` with shape `[5]` - that is, `torch.sum(x)`. This would give us back a `BatchedTensor` with shape `[]` (i.e. a scalar tensor). **But**, in reality, this is actually a `Tensor` with shape `[B]`. Instead of running `torch.sum(x: [5])`, we ran `torch.sum(x: [B, 5], dim=1)`. In other words, we transformed the sum operation so that instead of summing the whole tensor, it summed all the dimensions *except* the batch dimension.

That is how `vmap` works. For every single operator, we define how to transform that operator to broadcast over an additional batch dimension.
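
To make that framing concrete, here is a minimal runnable sketch, assuming a recent PyTorch where `vmap` is exposed as `torch.vmap` (older releases expose it as `functorch.vmap`):

```python
import torch

# A batch of 4 vectors, each with logical shape [5].
x = torch.randn(4, 5)

# Inside vmap, torch.sum sees a Tensor[5] and returns a "scalar",
# so the result is really a Tensor[B] of per-example sums.
per_example = torch.vmap(torch.sum)(x)

# This matches the rewrite described above: sum every dimension
# *except* the batch dimension.
assert torch.allclose(per_example, torch.sum(x, dim=1))
```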

@@ -64,8 +64,8 @@ In those cases, we have 2 primary tools - templates and boxed fallbacks. For exa
There are 3 primary boxed fallbacks that we've used (I'll refer to the macros here). If you feel that there's any pattern that we could/should abstract away, feel free to post an issue.

1. `POINTWISE_BOXED`: Handles pointwise ops. Takes all tensors in the arguments, moves batch dimensions to the front, and unsqueezes all tensors so that they broadcast.
- 1. `REDUCTION_BOXED`: Handles reduction ops. Moves batch dimension to the front, and then modifies the dim argument so that it works with the extra batch dimension. For example, if the dim is an integer, then we add one. If it's a dimarray, then we add one to all entries (unless it's empty!, in which case we fill in all the entries except 0).
- 1. `VARIADIC_BDIMS_BOXED`: Handles ops that already natively support arbitrary batch dimensions. For example, if it supports `[B1,B2,..., N]`. In this case, we can simply move the batch dimension to the front and we're done!
+ 2. `REDUCTION_BOXED`: Handles reduction ops. Moves batch dimension to the front, and then modifies the dim argument so that it works with the extra batch dimension. For example, if the dim is an integer, then we add one. If it's a dimarray, then we add one to all entries (unless it's empty!, in which case we fill in all the entries except 0).
+ 3. `VARIADIC_BDIMS_BOXED`: Handles ops that already natively support arbitrary batch dimensions. For example, if it supports `[B1,B2,..., N]`. In this case, we can simply move the batch dimension to the front and we're done!
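
Of the three, the `dim` bookkeeping in `REDUCTION_BOXED` is the subtlest. Below is an illustrative Python sketch of just that logic; `reduction_batch_rule` and its signature are hypothetical names for exposition, not the actual C++ fallback, and `dims` is assumed to already be normalized to non-negative indices:

```python
import torch

def reduction_batch_rule(op, x, dims, bdim=0):
    # Move the batch dimension to the front.
    x = x.movedim(bdim, 0)
    if len(dims) == 0:
        # An empty dim list means "reduce everything", so instead we
        # reduce every dimension *except* the batch dim at position 0.
        new_dims = list(range(1, x.dim()))
    else:
        # Shift each requested dim by one to step over the batch dim.
        new_dims = [d + 1 for d in dims]
    return op(x, dim=new_dims)

x = torch.randn(2, 3, 4)  # B=2, each example has shape [3, 4]
assert reduction_batch_rule(torch.sum, x, dims=[0]).shape == (2, 4)  # per-example dim 0
assert reduction_batch_rule(torch.sum, x, dims=[]).shape == (2,)     # "sum everything"
```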

### Sidestepping batching rules by decomposing operators
Sometimes, it's difficult to implement a batching rule by transforming it into another operator. For example, `trace`. In that case, instead of transforming the operator, we can simply decompose it.
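
As a hedged sketch of what decomposing buys us: if we rewrite `trace` in terms of `diagonal` and `sum`, which already have batching rules, the batched behavior falls out for free. The decomposition below mirrors the idea rather than the exact functorch implementation:

```python
import torch

# trace has no natural "shift the dims" rewrite, but it decomposes
# into ops that already have batching rules: take the diagonal, sum it.
def trace_decomposition(x):
    return torch.diagonal(x, dim1=-2, dim2=-1).sum(-1)

x = torch.randn(3, 4, 4)  # a batch of three 4x4 matrices
batched = torch.vmap(trace_decomposition)(x)
expected = torch.stack([torch.trace(m) for m in x])
assert torch.allclose(batched, expected)
```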