Tensor.cummax (tinygrad#7854)
generalized the existing cumsum to take Ops.MAX in addition to Ops.ADD
chenyuxyz authored Nov 22, 2024
1 parent fb10ea5 commit 3b26e51
Showing 6 changed files with 63 additions and 23 deletions.
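In one line, the change: cumsum and cummax are the same scan, differing only in the reduction op (Ops.ADD vs Ops.MAX) and its identity element. A minimal usage sketch against the post-commit API (expected values worked out by hand, not captured from a run):

```python
from tinygrad import Tensor

t = Tensor([1, 3, 2, 5, 4])
print(t.cumsum(0).numpy())  # running sum: 1, 4, 6, 11, 15
print(t.cummax(0).numpy())  # running max: 1, 3, 3, 5, 5
```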
1 change: 1 addition & 0 deletions docs/tensor/ops.md
@@ -27,6 +27,7 @@
::: tinygrad.Tensor.matmul
::: tinygrad.Tensor.einsum
::: tinygrad.Tensor.cumsum
::: tinygrad.Tensor.cummax
::: tinygrad.Tensor.triu
::: tinygrad.Tensor.tril
::: tinygrad.Tensor.interpolate
2 changes: 1 addition & 1 deletion extra/models/llama.py
@@ -135,7 +135,7 @@ def sample(logits: Tensor, temp: float, k: int, p: float, af: float, ap: float):

# approximate top p
# because we are already limited to top k elements we can do top p "without sorting"
-output_cumsum = output[::-1]._cumsum()[::-1] + t.sum()
+output_cumsum = output[::-1].cumsum()[::-1] + t.sum()
output = (output_cumsum >= (1 - p)) * output
output_indices = (output_cumsum >= (1 - p)) * output_indices
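The only change above is `_cumsum` becoming the public `cumsum`; the surrounding trick is worth a note. The double-reverse pattern computes a suffix sum, i.e. the probability mass of each entry plus everything ranked below it, which the sampler compares against `1 - p`. A plain-numpy sketch of just that pattern, assuming the probabilities are already sorted high-to-low by the preceding top-k step (the real `sample` also folds the non-top-k remainder in via `t.sum()`):

```python
import numpy as np

p = 0.6
probs = np.array([0.4, 0.3, 0.2, 0.1])  # assumed already limited to top k, highest first
suffix = probs[::-1].cumsum()[::-1]     # suffix[i] == probs[i:].sum(), no sort needed
keep = suffix >= (1 - p)                # a token survives while the tail mass is still >= 1 - p
print(probs * keep)                     # keeps 0.4 and 0.3: the smallest set with mass >= p
```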

3 changes: 2 additions & 1 deletion test/test_arange.py
@@ -6,6 +6,7 @@
from tinygrad.codegen.kernel import Opt, OptOps, Kernel, KernelOptError
from tinygrad.engine.realize import CompiledRunner, ExecItem
from tinygrad.engine.search import get_kernel_actions
+from tinygrad.ops import Ops

class TestArange(unittest.TestCase):
def _get_flops(self, N, opts=None):
@@ -86,7 +87,7 @@ def test_manual_index(self):
print("*** indexing ***")
with Context(NOOPT=1, FUSE_ARANGE=1):
GlobalCounters.reset()
-rng = Tensor.ones(4, 256, 16384, dtype=dtypes.int)._cumsum(axis=-1, _first_zero=True).reshape(4, 256, 16384, 1)
+rng = Tensor.ones(4, 256, 16384, dtype=dtypes.int)._cumalu(axis=-1, op=Ops.ADD, _include_initial=True).reshape(4, 256, 16384, 1)
idxs = idxs.reshape(4,1,1,1).expand(4, 256, 16384, 1)
reshape_dataset = dataset.T.reshape(1, 256, 16384, 1).expand(4, 256, 16384, 1)
full = (rng==idxs).where(reshape_dataset, Tensor.zeros(4, 256, 16384, 1))
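A note on the renamed flag used above: `_include_initial=True` (formerly `_first_zero`) prepends the identity element, turning the scan into an exclusive one, so a running sum over ones produces the indices 0..N-1 that this test matches against `idxs`. A plain-Python equivalent of the two variants (illustrative only, not tinygrad code):

```python
import itertools

ones = [1] * 5
inclusive = list(itertools.accumulate(ones))             # [1, 2, 3, 4, 5]
exclusive = list(itertools.accumulate([0] + ones))[:-1]  # [0, 1, 2, 3, 4]
print(inclusive, exclusive)
```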
21 changes: 21 additions & 0 deletions test/test_ops.py
@@ -801,6 +801,27 @@ def test_cumsum_zero_axis(self):
helper_test_op([(0,3)], lambda x: torch.cumsum(x, dim=0), lambda x: Tensor.cumsum(x, axis=0))
helper_test_op([(2,3,0)], lambda x: torch.cumsum(x, dim=2), lambda x: Tensor.cumsum(x, axis=2))

def test_small_cummax(self):
helper_test_op([(10)], lambda x: torch.cummax(x, dim=0).values, lambda x: Tensor.cummax(x, axis=0))
def test_simple_cummax(self):
helper_test_op([(512)], lambda x: torch.cummax(x, dim=0).values, lambda x: Tensor.cummax(x, axis=0))
helper_test_op([(1022)], lambda x: torch.cummax(x, dim=0).values, lambda x: Tensor.cummax(x, axis=0))
def test_cummax(self):
helper_test_op([()], lambda x: torch.cummax(x, dim=0).values, lambda x: Tensor.cummax(x, axis=0))
# TODO: torch allows this?
# self.helper_test_exception([()], lambda x: torch.cummax(x, dim=1).values, lambda x: Tensor.cummax(x, axis=1), expected=IndexError)
helper_test_op([(20,)], lambda x: torch.cummax(x, dim=0).values, lambda x: Tensor.cummax(x, axis=0))
self.helper_test_exception([(20,)], lambda x: torch.cummax(x, dim=1).values, lambda x: Tensor.cummax(x, axis=1), expected=IndexError)
self.helper_test_exception([(20,)], lambda x: torch.cummax(x, dim=-2).values, lambda x: Tensor.cummax(x, axis=-2), expected=IndexError)
helper_test_op([(20,30)], lambda x: torch.cummax(x, dim=0).values, lambda x: Tensor.cummax(x, axis=0))
helper_test_op([(20,30)], lambda x: torch.cummax(x, dim=1).values, lambda x: Tensor.cummax(x, axis=1))
helper_test_op([(20,30,40)], lambda x: torch.cummax(x, dim=2).values, lambda x: Tensor.cummax(x, axis=2))
helper_test_op([(20,30,40)], lambda x: torch.cummax(x, dim=-1).values, lambda x: Tensor.cummax(x, axis=-1))
def test_cummax_zero_axis(self):
helper_test_op([(2,0,4)], lambda x: torch.cummax(x, dim=1).values, lambda x: Tensor.cummax(x, axis=1))
helper_test_op([(0,3)], lambda x: torch.cummax(x, dim=0).values, lambda x: Tensor.cummax(x, axis=0))
helper_test_op([(2,3,0)], lambda x: torch.cummax(x, dim=2).values, lambda x: Tensor.cummax(x, axis=2))
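One detail these tests lean on: `torch.cummax` returns a `(values, indices)` pair, hence the `.values` on the torch side, while `Tensor.cummax` returns only the running-max values. A quick parity check (expected output written by hand):

```python
import torch
from tinygrad import Tensor

x = [0, 3, 1, 5, 2]
print(torch.cummax(torch.tensor(x), dim=0).values.tolist())  # [0, 3, 3, 5, 5]
print(Tensor(x).cummax(axis=0).numpy().tolist())             # [0, 3, 3, 5, 5]
```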

def test_argmax(self):
# check if it returns the first index for multiple occurences
self.assertEqual(torch.tensor([2,2]).argmax().numpy(), Tensor([2,2]).argmax().numpy())
2 changes: 1 addition & 1 deletion tinygrad/ops.py
@@ -183,7 +183,7 @@ class GroupOp:
UnsafePad = {Ops.RECIP, Ops.LOG2, Ops.EXP2, Ops.IDIV}

# https://en.wikipedia.org/wiki/Identity_element
-def identity_element(op:Ops, dt:DType): return dtypes.as_const({Ops.ADD:0, Ops.MUL:1, Ops.MAX:dtypes.min(dt)}[op], dt)
+def identity_element(op:Ops, dt:DType) -> ConstType: return dtypes.as_const({Ops.ADD:0, Ops.MUL:1, Ops.MAX:dtypes.min(dt)}[op], dt)

def can_pad(u:UOp) -> bool: return not any(x.op in GroupOp.UnsafePad for x in u.sparents)
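Why `identity_element` now matters to the cumulative ops: the scan pads its pooling windows, and padding with the op's identity (0 for ADD, 1 for MUL, the dtype's minimum for MAX) leaves every reduction unchanged. A tiny plain-Python illustration:

```python
import math

xs = [3.0, -1.0, 7.0]
assert sum(xs + [0.0]) == sum(xs)        # 0 is the identity for ADD
assert max(xs + [-math.inf]) == max(xs)  # dtypes.min(dt) plays the -inf role for MAX
print("padding with the identity is a no-op")
```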

57 changes: 37 additions & 20 deletions tinygrad/tensor.py
@@ -9,7 +9,7 @@
from tinygrad.helpers import argfix, make_tuple, flatten, prod, all_int, round_up, merge_dicts, argsort, getenv, all_same, fully_flatten, dedup
from tinygrad.helpers import IMAGE, DEBUG, WINO, _METADATA, Metadata, TRACEMETA, ceildiv, fetch, polyN
from tinygrad.multi import MultiLazyBuffer
-from tinygrad.ops import smax, smin, resolve, UOp, Ops, sint, Variable, SimpleMathTrait
+from tinygrad.ops import smax, smin, resolve, UOp, Ops, sint, Variable, SimpleMathTrait, identity_element
from tinygrad.device import Device, Buffer, BufferSpec
from tinygrad.engine.lazy import LazyBuffer
from tinygrad.engine.realize import run_schedule
@@ -607,7 +607,7 @@ def arange(start, stop=None, step=1, **kwargs) -> Tensor:
dtype = kwargs.pop("dtype", dtypes.default_float if any(isinstance(x, float) for x in (start, stop, step)) else dtypes.default_int)
# NOTE: this matches numpy, torch raises RuntimeError if stop-start and step have different signs
if (output_len:=ceildiv(stop-start, step)) <= 0: return Tensor([], dtype=dtype, **kwargs)
-return (Tensor.full((output_len,), step, dtype=dtype, **kwargs)._cumsum() + (start - step)).cast(dtype)
+return (Tensor.full((output_len,), step, dtype=dtype, **kwargs)._cumalu(0, Ops.ADD) + (start - step)).cast(dtype)

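For context on the changed line: `arange` is built as a cumulative sum over a step-filled tensor, shifted by `start - step`; the commit only swaps `_cumsum()` for the generalized `_cumalu(0, Ops.ADD)`. A plain-Python version of the construction with illustrative values:

```python
import itertools

start, stop, step = 3, 11, 2
n = -(-(stop - start) // step)  # ceildiv(stop - start, step)
vals = [c + (start - step) for c in itertools.accumulate([step] * n)]
print(vals)  # [3, 5, 7, 9] == list(range(3, 11, 2))
```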
@staticmethod
def linspace(start:Union[int, float], stop:Union[int, float], steps:int, **kwargs) -> Tensor:
@@ -2191,15 +2191,28 @@ def matmul(self, x:Tensor, reverse=False, acc_dtype:Optional[DTypeLike]=None) ->
"""
return x.dot(self, acc_dtype=acc_dtype) if reverse else self.dot(x, acc_dtype=acc_dtype)

-def _cumsum(self, axis:int=0, _first_zero=False) -> Tensor:
-assert self.shape[axis] != 0
-pl_sz = self.shape[axis] - int(not _first_zero)
-return self.transpose(axis,-1).pad((pl_sz,-int(_first_zero)))._pool((self.shape[axis],)).sum(-1).transpose(axis,-1)
+def _cumalu(self, axis:int, op:Ops, _include_initial=False) -> Tensor:
+assert self.shape[axis] != 0 and op in (Ops.ADD, Ops.MAX)
+pl_sz = self.shape[axis] - int(not _include_initial)
+pooled = self.transpose(axis,-1).pad((pl_sz, -int(_include_initial)), value=identity_element(op, self.dtype))._pool((self.shape[axis],))
+return (pooled.sum(-1) if op is Ops.ADD else pooled.max(-1)).transpose(axis,-1)

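How the new `_cumalu` works: pad `n-1` identity elements in front along the scan axis, take the `n` sliding windows of length `n` with `_pool`, and reduce each window with the op. A plain-Python sketch of that idea (illustrative, not the tinygrad implementation):

```python
def cum_via_pool(xs, reduce_fn, identity):
    n = len(xs)
    padded = [identity] * (n - 1) + list(xs)       # pad n-1 identities on the left
    windows = [padded[i:i + n] for i in range(n)]  # n sliding windows of length n
    return [reduce_fn(w) for w in windows]

print(cum_via_pool([1, 2, 3, 4], sum, 0))              # [1, 3, 6, 10]  == cumsum
print(cum_via_pool([0, 3, 1, 5], max, float("-inf")))  # [0, 3, 3, 5]   == cummax
```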
+def _split_cumalu(self, axis:int, op:Ops) -> Tensor:
+axis = self._resolve_dim(axis)
+if self.ndim == 0 or 0 in self.shape: return self
+# TODO: someday the optimizer will find this on it's own
+# for now this is a two stage cumsum
+SPLIT = 256
+if not isinstance(s:=self.shape[axis], int) or s <= SPLIT*2: return self._cumalu(axis, op)
+ret = self.transpose(axis,-1).pad((round_up(s, SPLIT)-s, 0), value=identity_element(op, self.dtype)).unflatten(-1, (-1, SPLIT))._cumalu(-1, op)
+base = ret[..., -1]._cumalu(-1, op, _include_initial=True)
+base = base.unsqueeze(-1).expand(*base.shape, ret.shape[-1])
+def fix(x:Tensor): return x.flatten(start_dim=-2)[..., -s:].transpose(axis,-1)
+return fix(ret) + fix(base) if op is Ops.ADD else fix(ret).maximum(fix(base))

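`_split_cumalu` is the two-stage variant for long axes: scan fixed-size blocks independently, then scan the block totals (with `_include_initial=True`, i.e. exclusively) and fold them back in; for MAX the combine step is `maximum` instead of `+`. A plain-Python sketch of the ADD case, with SPLIT shrunk to 4 for readability:

```python
import itertools

def blocked_cumsum(xs, split=4):
    blocks = [xs[i:i + split] for i in range(0, len(xs), split)]
    scanned = [list(itertools.accumulate(b)) for b in blocks]            # stage 1: per-block scan
    offsets = itertools.accumulate([0] + [b[-1] for b in scanned[:-1]])  # stage 2: exclusive scan of block totals
    return [v + off for b, off in zip(scanned, offsets) for v in b]

xs = list(range(1, 11))
assert blocked_cumsum(xs) == list(itertools.accumulate(xs))
print(blocked_cumsum(xs))  # [1, 3, 6, 10, 15, 21, 28, 36, 45, 55]
```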
def cumsum(self, axis:int=0) -> Tensor:
"""
-Computes the cumulative sum of the tensor along the specified axis.
-You can pass in the `axis` keyword argument to control the axis along which the cumulative sum is computed.
+Computes the cumulative sum of the tensor along the specified `axis`.
```python exec="true" source="above" session="tensor" result="python"
t = Tensor.ones(2, 3)
@@ -2209,17 +2222,21 @@ def cumsum(self, axis:int=0) -> Tensor:
print(t.cumsum(1).numpy())
```
"""
-axis = self._resolve_dim(axis)
-if self.ndim == 0 or 0 in self.shape: return self
-# TODO: someday the optimizer will find this on it's own
-# for now this is a two stage cumsum
-SPLIT = 256
-if not isinstance(s:=self.shape[axis], int) or s <= SPLIT*2: return self._cumsum(axis)
-ret = self.transpose(axis,-1).pad((round_up(s, SPLIT)-s, 0)).unflatten(-1, (-1, SPLIT))._cumsum(-1)
-base_add = ret[..., -1]._cumsum(-1, _first_zero=True)
-base_add = base_add.unsqueeze(-1).expand(*base_add.shape, ret.shape[-1])
-def fix(x:Tensor): return x.flatten(start_dim=-2)[..., -s:].transpose(axis,-1)
-return fix(ret) + fix(base_add)
+return self._split_cumalu(axis, Ops.ADD)

+def cummax(self, axis:int=0) -> Tensor:
+"""
+Computes the cumulative max of the tensor along the specified `axis`.
+```python exec="true" source="above" session="tensor" result="python"
+t = Tensor([0, 1, -1, 2, -2, 3, -3])
+print(t.numpy())
+```
+```python exec="true" source="above" session="tensor" result="python"
+print(t.cummax(0).numpy())
+```
+"""
+return self._split_cumalu(axis, Ops.MAX)

@staticmethod
def _tri(r:sint, c:sint, diagonal:int=0, **kwargs) -> Tensor:
