don't const rewrite in cstyle (tinygrad#7442)
* don't const rewrite in cstyle

* Update cstyle.py

* simple_symbolic

* fix bfloat16 const on AMD
geohot authored Oct 31, 2024
1 parent bdde795 commit 5dd1ffd
Showing 6 changed files with 26 additions and 13 deletions.
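In effect, the renderer now emits small-dtype constants directly as casted C literals such as ((half)2.0f), instead of extra_pm first rewriting them in the UOp graph into a larger-dtype CONST plus a CAST. A minimal way to see this, not part of the commit (run with DEBUG=4 to print the generated kernel):

from tinygrad import Tensor, dtypes

# the scalar 2.0 becomes a half CONST in the kernel; with this commit it is
# rendered as a casted literal rather than a float CONST + CAST in the graph
out = Tensor([1.0, 2.0, 3.0], dtype=dtypes.half) + 2.0
print(out.numpy())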
2 changes: 1 addition & 1 deletion test/helpers.py
@@ -34,7 +34,7 @@ def assert_jit_cache_len(fxn, expected_len):
 def is_dtype_supported(dtype: DType, device: str = Device.DEFAULT):
   if dtype == dtypes.bfloat16:
     # NOTE: this requires bf16 buffer support
-    return device in {"AMD"} or (device in {"CUDA", "NV"} and not CI and not getenv("PTX"))
+    return device in {"AMD"} or (device in {"CUDA", "NV", "METAL"} and not CI and not getenv("PTX"))
   if device in ["WEBGPU", "WEBGL"]: return dtype in [dtypes.float, dtypes.int32, dtypes.uint32]
   # for CI GPU and OSX, cl_khr_fp16 isn't supported
   # for CI LLVM, it segfaults because it can't link to the casting function
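With METAL added, bfloat16 tests also run on Metal outside CI. A hypothetical guard in the style this helper exists for (class and test names are illustrative, not from the commit):

import unittest
from tinygrad import Tensor, Device, dtypes
from test.helpers import is_dtype_supported

class TestBF16(unittest.TestCase):
  # skip cleanly on devices where bf16 buffers are unsupported
  @unittest.skipUnless(is_dtype_supported(dtypes.bfloat16), f"no bf16 on {Device.DEFAULT}")
  def test_bf16_add(self):
    out = Tensor([1.0], dtype=dtypes.bfloat16) + 1
    self.assertEqual(out.cast(dtypes.float).item(), 2.0)

if __name__ == "__main__": unittest.main()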
8 changes: 8 additions & 0 deletions test_driven_development.sh
@@ -0,0 +1,8 @@
+#!/bin/bash
+python3 test/external/process_replay/reset.py
+RUN_PROCESS_REPLAY=1 pytest -n auto test/test_tiny.py test/test_uop_graph.py test/test_ops.py test/test_linearizer.py
+while true; do
+  if python3 test/test_tiny.py; then
+    PYTHONPATH="." python3 test/external/process_replay/process_replay.py
+  fi
+done
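A plausible reading of the new helper, from its commands alone: reset.py clears the process-replay database, the RUN_PROCESS_REPLAY=1 pytest run records a baseline of generated kernels, and the loop then reruns test/test_tiny.py, diffing freshly generated kernels against that baseline with process_replay.py after each passing run.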
4 changes: 2 additions & 2 deletions tinygrad/codegen/uopgraph.py
@@ -3,7 +3,7 @@
 import functools, itertools, operator
 from collections import defaultdict
 from tinygrad.dtype import dtypes, ImageDType, PtrDType
-from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps, UOp, UOps, UPat, PatternMatcher, symbolic_flat
+from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps, UOp, UOps, UPat, PatternMatcher, symbolic_flat, symbolic_simple
 from tinygrad.ops import graph_rewrite, is_irreducible, split_uop, identity_element, uop_given_valid, parse_valid, is_increasing, simplify_valid
 from tinygrad.helpers import DEBUG, getenv, flatten, dedup, TRANSCENDENTAL, AMX, prod, partition, all_same
 from tinygrad.codegen.transcendental import xexp2, xlog2, xsin, TRANSCENDENTAL_SUPPORTED_DTYPES
@@ -525,5 +525,5 @@ def full_graph_rewrite(sink:UOp, opts:Optional[Renderer]=None) -> UOp:
   sink = graph_rewrite(sink, sym+get_extra_patterns(tuple(opts.code_for_op.keys()) if opts is not None else (), TRANSCENDENTAL>=2))

   # for rendering without sym (including the rules from the renderer)
-  sink = graph_rewrite(sink, pm_render+opts.extra_matcher if opts is not None and opts.extra_matcher is not None else pm_render)
+  sink = graph_rewrite(sink, symbolic_simple+(pm_render+opts.extra_matcher if opts is not None and opts.extra_matcher is not None else pm_render))
   return sink
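Adding symbolic_simple to the render pass means trivial identities still fold when rendering without sym. A hedged sketch against the internals this diff imports (not code from the commit):

from tinygrad.ops import UOp, graph_rewrite, symbolic_simple
from tinygrad.dtype import dtypes

# x*1 -> x under symbolic_simple's self-folding rules
x = UOp.const(dtypes.int, 7)
print(graph_rewrite(x * 1, symbolic_simple))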
11 changes: 7 additions & 4 deletions tinygrad/ops.py
@@ -1003,7 +1003,7 @@ def max_var_const(x:UOp, c1:UOp, c2:UOp):
   if x.vmin >= 0: return x*c1 if c1.arg >= c2.arg else x*c2
   if x.vmax <= 0: return x*c2 if c1.arg >= c2.arg else x*c1

-symbolic = PatternMatcher([
+symbolic_simple = PatternMatcher([
   # ** self folding **
   (UPat.var("x") + 0, lambda x: x), # x+0 -> x
   (UPat.var("x") * 1, lambda x: x), # x*1 -> x
@@ -1036,6 +1036,12 @@ def max_var_const(x:UOp, c1:UOp, c2:UOp):
   (UPat.var('x', dtype=dtypes.bool) * UPat.var('y'), lambda x,y: x&y),
   (UPat.var('x', dtype=dtypes.bool) + UPat.var('y'), lambda x,y: x|y),
   (UPat.var('x', dtype=dtypes.bool).maximum(UPat.var('y')), lambda x,y: x|y),
+  # *** cast ***
+  (UPat(UOps.CAST, name="root", src=UPat.cvar("c")), lambda root, c: root.const_like(c.arg)),
+  (UPat(UOps.CAST, name="root"), lambda root: root.src[0] if root.dtype == root.src[0].dtype else None),
+])
+
+symbolic = symbolic_simple+PatternMatcher([
   # group like
   ((UPat.var("x") + UPat.var("y")) + UPat.var("x") * UPat.cvar("c"), lambda x,y,c: (x+x*c)+y),
   # ** combine terms **
@@ -1091,9 +1097,6 @@ def max_var_const(x:UOp, c1:UOp, c2:UOp):
   # ** mod **
   # mod folding
   (UPat.var("x") % UPat.cvar("c", vec=False), lambda x,c: newx if 0 < c.arg and (newx:=mod_folding(x,c.arg)) is not None else None),
-  # *** cast ***
-  (UPat(UOps.CAST, name="root", src=UPat.cvar("c")), lambda root, c: root.const_like(c.arg)),
-  (UPat(UOps.CAST, name="root"), lambda root: root.src[0] if root.dtype == root.src[0].dtype else None),
 ])

 symbolic_flat = symbolic+PatternMatcher([
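The cast rules moved rather than changed: CAST-of-CONST folding and no-op CAST removal now live in symbolic_simple, so they also fire in the render-time pass added in uopgraph.py. A hedged sketch (assuming these internals):

from tinygrad.ops import UOp, UOps, graph_rewrite, symbolic_simple
from tinygrad.dtype import dtypes

# CAST(CONST 3) folds straight to a float CONST via root.const_like(c.arg)
folded = graph_rewrite(UOp.const(dtypes.int, 3).cast(dtypes.float), symbolic_simple)
print(folded.op is UOps.CONST, folded.dtype)  # expect: True float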
10 changes: 6 additions & 4 deletions tinygrad/renderer/cstyle.py
@@ -34,6 +34,11 @@
   (UPat(UOps.CONST, dtype=dtypes.uint64, name="x"), lambda ctx,x: f"{x.arg}ull"),
   (UPat(UOps.CONST, dtype=dtypes.uint32, name="x"), lambda ctx,x: f"{x.arg}u"),
   (UPat(UOps.CONST, dtype=dtypes.bool, name="x"), lambda ctx,x: "1" if x.arg else "0"),
+  # consts are rendered to larger type and casted
+  (UPat(UOps.CONST, (dtypes.bfloat16, dtypes.half), name="x"), lambda ctx,x: f"(({ctx.render_dtype(x.dtype)}){x.arg}f)"),
+  (UPat(UOps.CONST, (dtypes.uint8, dtypes.uint16), name="x"), lambda ctx,x: f"(({ctx.render_dtype(x.dtype)}){x.arg}u)"),
+  (UPat(UOps.CONST, (dtypes.int8, dtypes.int16), name="x"), lambda ctx,x: f"(({ctx.render_dtype(x.dtype)}){x.arg})"),
+  # default const render
   (UPat(UOps.CONST, name="x"), lambda ctx,x: str(x.arg)),
   # new load/store
   (UPat(UOps.INDEX, src=(UPat.var("buf"), UPat.var('idx'))),
@@ -49,10 +54,6 @@
 ])

 extra_pm = PatternMatcher([
-  # consts are rendered to larger type and casted
-  (UPat(UOps.CONST, (dtypes.bfloat16, dtypes.half), name="c"), lambda c: UOp.const(dtypes.float, c.arg).cast(c.dtype)),
-  (UPat(UOps.CONST, (dtypes.uint8, dtypes.uint16), name="c"), lambda c: UOp.const(dtypes.uint32, c.arg).cast(c.dtype)),
-  (UPat(UOps.CONST, (dtypes.int8, dtypes.int16), name="c"), lambda c: UOp.const(dtypes.int32, c.arg).cast(c.dtype)),
   # insert a NOOP before BITCAST to force it to be rendered. not needed on all backends?
   (UPat(UOps.BITCAST, name="x"),
     lambda x: UOp(UOps.BITCAST, x.dtype, (UOp(UOps.NOOP, x.src[0].dtype, x.src),)) if x.src[0].op is not UOps.NOOP else None),
@@ -396,6 +397,7 @@ class AMDRenderer(CStyleLanguage):
   (UPat(UOps.CAST, name="x", src=UPat.var("y", dtypes.bfloat16)),lambda x,y: y.cast(dtypes.float).cast(x.dtype) if x.dtype!=dtypes.float else None),
   (UPat(UOps.CAST, dtypes.bfloat16, UPat.var("x")),lambda x: x.cast(dtypes.float).cast(dtypes.bfloat16) if x.dtype!=dtypes.float else None),
   # bfloat16 casting
+  (UPat.cvar('x', dtypes.bfloat16), lambda x: cast_float_to_bf16(UOp.const(dtypes.float, x.arg))),
   (UPat(UOps.CAST, dtype=dtypes.float, src=UPat.var("x", dtype=dtypes.bfloat16)),
     lambda x: (x.bitcast(dtypes.ushort).cast(dtypes.uint)<<16).bitcast(dtypes.float)),
   (UPat(UOps.CAST, dtype=dtypes.bfloat16, src=UPat.var("x", dtype=dtypes.float)), cast_float_to_bf16)]) + extra_pm
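The added AMD rule is the "fix bfloat16 const on AMD" from the commit message: a bf16 CONST becomes a float CONST pushed through cast_float_to_bf16, so no bf16 literal reaches the generated source. An illustration mirroring the rule by hand (assuming cast_float_to_bf16 is importable from this module, as the diff suggests):

from tinygrad.ops import UOp
from tinygrad.dtype import dtypes
from tinygrad.renderer.cstyle import cast_float_to_bf16

# what the matcher does to a bf16 const: re-express it via a float const
c = UOp.const(dtypes.bfloat16, 1.5)
print(cast_float_to_bf16(UOp.const(dtypes.float, c.arg)))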
4 changes: 2 additions & 2 deletions tinygrad/renderer/ptx.py
@@ -1,7 +1,7 @@
 from typing import DefaultDict, Dict, List, Union, Optional, cast, Callable
 import struct
 from collections import defaultdict
-from tinygrad.ops import BinaryOps, UnaryOps, TernaryOps, Op, UOps, UOp, PatternMatcher, UPat, symbolic
+from tinygrad.ops import BinaryOps, UnaryOps, TernaryOps, Op, UOps, UOp, PatternMatcher, UPat
 from tinygrad.dtype import dtypes, DType, PtrDType, ConstType
 from tinygrad.renderer import Renderer
 from tinygrad.renderer.cstyle import CUDARenderer
@@ -33,7 +33,7 @@ def render_val(x, dtype):
 }

 supports_half: List[Op] = [UnaryOps.EXP2, BinaryOps.ADD, BinaryOps.MUL, BinaryOps.MAX, BinaryOps.CMPLT, TernaryOps.WHERE]
-ptx_matcher = symbolic+PatternMatcher([
+ptx_matcher = PatternMatcher([
   # bool CMPNE is XOR, bool CMPLT is XOR+AND (universal makes this slow, this is for renderer only)
   (UPat.var('x', dtype=dtypes.bool).ne(UPat.var('y')), lambda x,y: x^y),
   (UPat.var('x', dtype=dtypes.bool).lt(UPat.var('y')), lambda x,y: (x^True)&y),
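Dropping symbolic here lines up with the uopgraph.py change above: simplification now happens once in full_graph_rewrite, so ptx_matcher keeps only its renderer-specific rewrites.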
