update docs, remove corealize (tinygrad#4264)
* update docs, remove corealize

* handle 0 line count

* tensor schedule
geohot authored Apr 23, 2024
1 parent 9b7efa7 commit 967638f
Showing 11 changed files with 62 additions and 33 deletions.
20 changes: 20 additions & 0 deletions docs/developer.md
@@ -5,3 +5,23 @@ Everything in [Tensor](tensor.md) is syntactic sugar around [function.py](functi
::: tinygrad.lazy.LazyBuffer
options:
show_source: false

## Lowering

The [scheduler](/tinygrad/engine/schedule.py) converts the graph of LazyBuffers into a list of `ScheduleItem`. Each item's `ast` specifies what compute to run, and its `bufs` specifies which buffers to run it on.

::: tinygrad.ops.ScheduleItem
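
For example, a minimal sketch of building a schedule and inspecting it (the tensors are arbitrary; it uses only `Tensor.schedule` and the `ast`/`bufs` fields described above):

```python
# a minimal sketch; nothing runs until the schedule is executed
from tinygrad import Tensor

a, b = Tensor.rand(16, 16), Tensor.rand(16, 16)
out = (a @ b).relu()

# schedule() walks the LazyBuffer graph and returns the ScheduleItems needed to realize `out`
for si in out.schedule():
  print(si.ast, si.bufs)  # what compute to run, and which buffers to run it on
```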

The code in [realize](/tinygrad/engine/realize.py) lowers each `ScheduleItem` to an `ExecItem` with

::: tinygrad.engine.realize.lower_schedule
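
A minimal sketch of lowering, assuming `lower_schedule` accepts the schedule list from `Tensor.schedule` and yields one `ExecItem` per `ScheduleItem`:

```python
# a minimal sketch; lower_schedule's exact signature is an assumption
from tinygrad import Tensor
from tinygrad.engine.realize import lower_schedule

out = (Tensor.rand(4, 4) + 1).sum()
for ei in lower_schedule(out.schedule()):
  print(ei)  # each ExecItem pairs lowered compute with the buffers it touches
```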

## Execution

An `ExecItem` has a `run` method that executes it:

::: tinygrad.engine.realize.ExecItem
options:
members: true
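
A minimal sketch of running the lowered items by hand, plus the equivalent call this commit's `Tensor.realize` makes (see the tensor.py diff below); `run`'s default arguments are an assumption:

```python
# a minimal sketch; ExecItem.run's default arguments are assumed
from tinygrad import Tensor
from tinygrad.engine.realize import lower_schedule, run_schedule

t = (Tensor.ones(32) * 2).contiguous()
for ei in lower_schedule(t.schedule()):
  ei.run()  # execute one lowered kernel (or buffer copy)

# equivalently, what Tensor.realize does in this commit:
u = (Tensor.ones(32) * 3).contiguous()
run_schedule(*u.schedule_with_vars())
```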

Lists of `ExecItem` can be condensed into a single `ExecItem` with the Graph API (rename to Queue?)
3 changes: 2 additions & 1 deletion docs/tensor.md
@@ -12,7 +12,8 @@

## tinygrad ops

::: tinygrad.Tensor.corealize
::: tinygrad.Tensor.schedule_with_vars
::: tinygrad.Tensor.schedule
::: tinygrad.Tensor.realize
::: tinygrad.Tensor.replace
::: tinygrad.Tensor.assign
2 changes: 2 additions & 0 deletions serve_docs.sh
@@ -0,0 +1,2 @@
#!/bin/bash
mkdocs serve -w tinygrad/
6 changes: 3 additions & 3 deletions test/external/external_benchmark_resnet.py
@@ -55,7 +55,7 @@ def _get_layer(self, layer_i, slice_i):
return f"{name} x{(bs, cin, xy, xy)}", [layer], cin, xy
def _test_layer(self, name, layer, cin, xy):
optim = SGD(get_parameters(layer), bs / 128 * 1.0) # need sgd for some params but not consequential for benchmarking
with Context(SAVE_SCHEDULE=0): Tensor.corealize([t.assign(t) for t in get_parameters(layer)])
with Context(SAVE_SCHEDULE=0): Tensor.realize(*[t.assign(t) for t in get_parameters(layer)])

JITCNT = getenv("JITCNT", 1)
Tensor.training = True
@@ -67,8 +67,8 @@ def step(x):

y = x.sequential(layer).contiguous().contiguous_backward()
y.sum().backward()
if getenv("ASSIGN", 1): Tensor.corealize([y, x.grad] + optim.schedule_step())
else: Tensor.corealize([y, x.grad] + [t.grad for t in optim.params])
if getenv("ASSIGN", 1): Tensor.realize(y, x.grad, *optim.schedule_step())
else: Tensor.realize(y, x.grad, *[t.grad for t in optim.params])
return y.detach()

CNT = getenv("CNT", 5)
2 changes: 1 addition & 1 deletion test/external/external_test_opt.py
@@ -144,7 +144,7 @@ def test_loop_right(self):

@unittest.skipUnless(Device.DEFAULT == "GPU", "Not Implemented")
class TestOptWChild(unittest.TestCase):
@unittest.skip("this no longer happens, use corealize")
@unittest.skip("this no longer happens, use realize")
def test_unrealized_child(self):
a = Tensor.randn(16, 16)
b = Tensor.randn(16, 16)
4 changes: 2 additions & 2 deletions test/test_assign.py
@@ -171,7 +171,7 @@ def test_crossover_assign(self):
b = Tensor.full((4,), 3).contiguous().realize()
a += b
b += a
Tensor.corealize([a,b])
Tensor.realize(a,b)
np.testing.assert_allclose(a.numpy(), 5)
np.testing.assert_allclose(b.numpy(), 8)

@@ -183,7 +183,7 @@ def test_crossunder_assign(self):
c = a+9
a += b
b += c
Tensor.corealize([a,b])
Tensor.realize(a,b)
np.testing.assert_allclose(a.numpy(), 2+3)
np.testing.assert_allclose(b.numpy(), 3+2+9)

2 changes: 1 addition & 1 deletion test/test_schedule.py
@@ -18,7 +18,7 @@ def check_schedule(t:Union[Tensor, List[Tensor]], allowed:int, to_prerealize:Opt
seen = set()
if to_prerealize:
for pre in to_prerealize:
for s in create_schedule([pre.lazydata], seen.copy()):
for s in pre.schedule(seen=seen.copy()):
for i,out in enumerate(s.outputs):
if GRAPH: realized_lazybuffer(out, 0)
seen.add(out)
Empty file added tinygrad/engine/__init__.py
24 changes: 12 additions & 12 deletions tinygrad/engine/jit.py
@@ -96,7 +96,7 @@ def __get__(self, obj, objtype): return functools.partial(self.__call__, obj) #
def __call__(self, *args, **kwargs) -> ReturnType:
input_tensors: List[Tuple[Union[int, str], Tensor]] = \
[(cast(Union[int, str], k),v) for k,v in itertools.chain(enumerate(args), sorted(kwargs.items())) if v.__class__ is Tensor]
Tensor.corealize([x[1] for x in input_tensors])
if len(input_tensors): Tensor.realize(*[x[1] for x in input_tensors])
lbs: List[LazyBuffer] = flatten([v.lazydata.lbs for _,v in input_tensors])
expected_sts_var_dtype_device = [(*x.st.unbind(), x.dtype, x.device) for x in lbs]
input_rawbuffers: List[Buffer] = [v.base.realized for v in lbs if v.base.realized is not None]
@@ -105,20 +105,18 @@ def __call__(self, *args, **kwargs) -> ReturnType:
[dict(x.unbind() for x in itertools.chain(args, kwargs.values()) if isinstance(x, Variable))])

expected_names, expected_lbs = [x[0] for x in input_tensors], [(x[0], tuple(x[1].keys()), x[2], x[3]) for x in expected_sts_var_dtype_device]
if self.cnt >= 2:
# jit exec
assert self.expected_names == expected_names and self.expected_lbs == expected_lbs, "args mismatch in JIT"
for (j,i),input_idx in self.input_replace.items(): self.jit_cache[j].rawbufs[i] = input_rawbuffers[input_idx]
if DEBUG >= 1 and len(self.jit_cache) >= 10: print(f"jit execs {len(self.jit_cache)} kernels")
for ei in self.jit_cache: ei.run(var_vals, jit=True)
if self.cnt == 0:
# jit ignore
self.ret = self.fxn(*args, **kwargs)
if len(params:=get_parameters(self.ret)): Tensor.realize(params[0], *params[1:])
elif self.cnt == 1:
# jit capture
self.expected_names: List[Union[int, str]] = expected_names
self.expected_lbs: List[Tuple[ShapeTracker, Tuple[Variable, ...], DType, str]] = expected_lbs
with Context(GRAPH=getenv("JITGRAPH", GRAPH.value), BEAM=getenv("JITBEAM", BEAM.value)):
capturing.append(self)
self.ret = self.fxn(*args, **kwargs)
Tensor.corealize(get_parameters(self.ret))
if len(params:=get_parameters(self.ret)): Tensor.realize(params[0], *params[1:])
capturing.clear()
del self.buffer_replace
assert len(self.jit_cache), "didn't JIT anything!"
@@ -133,10 +131,12 @@ def __call__(self, *args, **kwargs) -> ReturnType:

self.input_replace = get_input_replace(self.jit_cache, input_rawbuffers)
if DEBUG >= 1 and len(set(self.input_replace.values())) != len(input_rawbuffers): print("WARNING: some input tensors not found")
elif self.cnt == 0:
# jit ignore
self.ret = self.fxn(*args, **kwargs)
Tensor.corealize(get_parameters(self.ret))
elif self.cnt >= 2:
# jit exec
assert self.expected_names == expected_names and self.expected_lbs == expected_lbs, "args mismatch in JIT"
for (j,i),input_idx in self.input_replace.items(): self.jit_cache[j].rawbufs[i] = input_rawbuffers[input_idx]
if DEBUG >= 1 and len(self.jit_cache) >= 10: print(f"jit execs {len(self.jit_cache)} kernels")
for ei in self.jit_cache: ei.run(var_vals, jit=True)

# clear jit inputs
for (j,i) in self.input_replace.keys(): self.jit_cache[j].rawbufs[i] = None
2 changes: 1 addition & 1 deletion tinygrad/nn/optim.py
@@ -18,7 +18,7 @@ def __init__(self, params: List[Tensor], lr: float):
def zero_grad(self):
for param in self.params: param.grad = None

def step(self): Tensor.corealize(self.schedule_step())
def step(self): Tensor.realize(*self.schedule_step())
def schedule_step(self) -> List[Tensor]: return self._step()+self.params+self.buffers
def _step(self) -> List[Tensor]: raise NotImplementedError

30 changes: 18 additions & 12 deletions tinygrad/tensor.py
@@ -2,7 +2,7 @@
from __future__ import annotations
import time, math, itertools, functools
from contextlib import ContextDecorator
from typing import List, Tuple, Callable, Optional, ClassVar, Type, Union, Sequence, Iterable, Dict, DefaultDict, cast, get_args
from typing import List, Tuple, Callable, Optional, ClassVar, Type, Union, Sequence, Dict, DefaultDict, cast, get_args, Set
from collections import defaultdict
import numpy as np

@@ -11,10 +11,10 @@
from tinygrad.helpers import getenv
from tinygrad.lazy import LazyBuffer
from tinygrad.features.multi import MultiLazyBuffer
from tinygrad.ops import LoadOps
from tinygrad.ops import LoadOps, ScheduleItem
from tinygrad.buffer import Buffer, BufferOptions
from tinygrad.device import Device
from tinygrad.shape.symbolic import sint
from tinygrad.shape.symbolic import sint, Variable
from tinygrad.engine.realize import run_schedule, memory_planner
from tinygrad.engine.schedule import create_schedule_with_vars

@@ -146,17 +146,23 @@ def dtype(self) -> DType: return self.lazydata.dtype

# ***** data handlers ****

@staticmethod
def corealize(lst:Iterable[Tensor]):
def schedule_with_vars(self, *lst:Tensor, seen:Optional[Set[LazyBuffer]]=None) -> Tuple[List[ScheduleItem], Dict[Variable, int]]:
"""Create the schedule needed to realize these Tensor(s), with Variables."""
if getenv("FUZZ_SCHEDULE"):
from test.external.fuzz_schedule import fuzz_schedule
fuzz_schedule(flatten([x.lazydata.lbs for x in lst]))
schedule, var_vals = create_schedule_with_vars(flatten([x.lazydata.lbs for x in lst]))
run_schedule(memory_planner(schedule), var_vals)

def realize(self) -> Tensor:
"""Trigger the computation needed to create this Tensor. This is a light wrapper around corealize."""
Tensor.corealize([self])
fuzz_schedule(flatten([x.lazydata.lbs for x in (self,)+lst]))
schedule, var_vals = create_schedule_with_vars(flatten([x.lazydata.lbs for x in (self,)+lst]), seen)
return memory_planner(schedule), var_vals

def schedule(self, *lst:Tensor, seen:Optional[Set[LazyBuffer]]=None) -> List[ScheduleItem]:
"""Create the schedule needed to realize these Tensor(s)."""
schedule, var_vals = self.schedule_with_vars(*lst, seen=seen)
assert len(var_vals) == 0
return schedule

def realize(self, *lst:Tensor) -> Tensor:
"""Trigger the computation needed to create these Tensor(s)."""
run_schedule(*self.schedule_with_vars(*lst))
return self

def replace(self, x:Tensor) -> Tensor:
