Skip to content

Commit

Permalink
improve pre-commit [pr] (tinygrad#7256)
Browse files Browse the repository at this point in the history
* improve pre-commit [pr]

* mypy passes on windows
  • Loading branch information
geohot authored Oct 24, 2024
1 parent b1a3067 commit de7b9d7
Show file tree
Hide file tree
Showing 9 changed files with 38 additions and 33 deletions.
13 changes: 7 additions & 6 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
# On Windows (PowerShell), set $env:SKIP="devicetests,tests,example" before running pre-commit
repos:
- repo: local
hooks:
Expand All @@ -7,15 +8,15 @@ repos:
language: system
always_run: true
pass_filenames: false
- id: mypy
name: mypy
entry: python3 -m mypy tinygrad/ --strict-equality
- id: tiny
name: tiny tests
entry: python3 -m pytest test/test_tiny.py
language: system
always_run: true
pass_filenames: false
- id: docs2
name: docs2
entry: python3 docs/abstractions2.py
- id: mypy
name: mypy
entry: python3 -m mypy tinygrad/ --strict-equality
language: system
always_run: true
pass_filenames: false
Expand Down
2 changes: 1 addition & 1 deletion mypy.ini
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,4 @@ explicit_package_bases = True
warn_unreachable = True
warn_redundant_casts = True
# NOTE: had to comment this out to make mypy pass on both CI and OSX
#warn_unused_ignores = True
#warn_unused_ignores = True
6 changes: 3 additions & 3 deletions tinygrad/engine/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,8 @@ class TimeoutException(Exception): pass
def timeout_handler(signum, frame): raise TimeoutException()

def _try_compile_linearized_w_idx(x:Tuple[int,Kernel], compiler:Compiler) -> Tuple[int, Optional[Tuple[Program, bytes, float]]]:
if hasattr(signal, "SIGALRM"):
signal.signal(signal.SIGALRM, timeout_handler)
if hasattr(signal, "alarm"):
signal.signal(getattr(signal, 'SIGALRM'), timeout_handler)
# set timeout
signal.alarm(getenv("BEAM_TIMEOUT_SEC", 10))
ret = None
Expand All @@ -74,7 +74,7 @@ def _try_compile_linearized_w_idx(x:Tuple[int,Kernel], compiler:Compiler) -> Tup
except Exception as e:
if getenv("BEAM_STRICT_MODE"): raise e
finally:
if hasattr(signal, "SIGALRM"): signal.alarm(0)
if hasattr(signal, "alarm"): signal.alarm(0)
return x[0], ret

# workers should ignore ctrl c
Expand Down
2 changes: 1 addition & 1 deletion tinygrad/renderer/cstyle.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,7 @@ class MetalRenderer(CStyleLanguage):
tensor_cores = [TensorCore(dims=(8,8,8),threads=[(0,2),(1,4),(0,2),(1,2)],expanded_shape=(2,2,2,2),upcast_axes=([(1,2)],[(1,2)],[(1,2)]),
st1_pattern=(((1,1),(0,1),(1,0),(0,3)),((0,0),(0,2),(1,3),(1,2))),st2_pattern=(((0,0),(1,1),(1,2),(0,2),(1,0)),((0,1),(0,3),(1,3))),
dtype_in=di,dtype_out=do,reduce_axes=[(0,8)]) for di,do in [(dtypes.float,dtypes.float),(dtypes.half,dtypes.float),(dtypes.half,dtypes.half)]]
def __init__(self): self.tensor_cores = MetalRenderer.tensor_cores if os.uname().machine == "arm64" else []
def __init__(self): self.tensor_cores = MetalRenderer.tensor_cores if hasattr(os, 'uname') and os.uname().machine == "arm64" else []

# language options
kernel_prefix = "kernel "
Expand Down
3 changes: 2 additions & 1 deletion tinygrad/runtime/ops_amd.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations
from typing import Tuple, List, Any
import os, ctypes, ctypes.util, functools, pathlib, mmap, errno, time, array, contextlib, decimal
import os, ctypes, ctypes.util, functools, pathlib, mmap, errno, time, array, contextlib, decimal, sys
assert sys.platform != 'win32'
from dataclasses import dataclass
from tinygrad.runtime.support.hcq import HCQCompiled, HCQAllocator, HCQBuffer, HWComputeQueue, HWCopyQueue, HCQArgsState, HCQSignal, HCQProgram
from tinygrad.device import BufferOptions
Expand Down
36 changes: 18 additions & 18 deletions tinygrad/runtime/ops_disk.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from __future__ import annotations
import os, sys, mmap, io, ctypes, ctypes.util, platform, contextlib
import os, sys, mmap, io, ctypes, ctypes.util, contextlib
from typing import Optional, Generator, Tuple, Callable, List
from tinygrad.helpers import OSX, round_up
from tinygrad.device import Compiled, Allocator
Expand Down Expand Up @@ -84,7 +84,7 @@ def _might_open(self, size):
filename = self.dname[len("disk:"):]
self.size = size

if filename.startswith("shm:"):
if sys.platform != "win32" and filename.startswith("shm:"):
fd = _posixshmem.shm_open("/"+filename[4:].lstrip("/"), os.O_RDWR, 0o600)
self.mem = mmap.mmap(fd, self.size, mmap.MAP_SHARED | MAP_POPULATE | MAP_LOCKED)
os.close(fd)
Expand All @@ -93,7 +93,7 @@ def _might_open(self, size):
except OSError: self.fd = os.open(filename, os.O_RDWR|os.O_CREAT)
if os.fstat(self.fd).st_size < self.size: os.ftruncate(self.fd, self.size)
self.mem = mmap.mmap(self.fd, self.size)
if (hp := getattr(mmap, "MADV_HUGEPAGE", None)) is not None:
if hasattr(self.mem, 'madvise') and (hp := getattr(mmap, "MADV_HUGEPAGE", None)) is not None:
with contextlib.suppress(OSError): self.mem.madvise(hp) # some systems have transparent_hugepage disabled
def _might_close(self):
self.count -= 1
Expand All @@ -103,22 +103,22 @@ def _might_close(self):
def _iouring_setup(self):
DiskDevice._tried_io_uring_init = True

if platform.system() != 'Linux' or hasattr(sys, "getandroidapilevel"): return
if sys.platform == 'linux' and not hasattr(sys, "getandroidapilevel"):
fd = libc.syscall(io_uring.NR_io_uring_setup, 4096, ctypes.byref(p:=io_uring.struct_io_uring_params()))
if fd < 0: return

fd = libc.syscall(io_uring.NR_io_uring_setup, 4096, ctypes.byref(p:=io_uring.struct_io_uring_params()))
if fd < 0: return
sq_ptr = libc.mmap(0, p.sq_off.array + p.sq_entries * 4, mmap.PROT_READ | mmap.PROT_WRITE, mmap.MAP_SHARED | MAP_POPULATE, fd, 0)
cq_ptr = libc.mmap(0, p.cq_off.cqes + p.cq_entries * ctypes.sizeof(io_uring.struct_io_uring_cqe),
mmap.PROT_READ | mmap.PROT_WRITE, mmap.MAP_SHARED | MAP_POPULATE, fd, io_uring.IORING_OFF_CQ_RING)
sqes = libc.mmap(0, p.sq_entries * ctypes.sizeof(io_uring.struct_io_uring_sqe),
mmap.PROT_READ | mmap.PROT_WRITE, mmap.MAP_SHARED | MAP_POPULATE, fd, io_uring.IORING_OFF_SQES)

sq_ptr = libc.mmap(0, p.sq_off.array + p.sq_entries * 4, mmap.PROT_READ | mmap.PROT_WRITE, mmap.MAP_SHARED | MAP_POPULATE, fd, 0)
cq_ptr = libc.mmap(0, p.cq_off.cqes + p.cq_entries * ctypes.sizeof(io_uring.struct_io_uring_cqe),
mmap.PROT_READ | mmap.PROT_WRITE, mmap.MAP_SHARED | MAP_POPULATE, fd, io_uring.IORING_OFF_CQ_RING)
sqes = libc.mmap(0, p.sq_entries * ctypes.sizeof(io_uring.struct_io_uring_sqe),
mmap.PROT_READ | mmap.PROT_WRITE, mmap.MAP_SHARED | MAP_POPULATE, fd, io_uring.IORING_OFF_SQES)
def u32ptr(val): return ctypes.cast(val, ctypes.POINTER(ctypes.c_uint32))
sqdesc = io_uring.struct_io_uring_sq(khead=u32ptr(sq_ptr+p.sq_off.head), ktail=u32ptr(sq_ptr+p.sq_off.tail),
array=u32ptr(sq_ptr+p.sq_off.array),
kring_mask=u32ptr(sq_ptr+p.sq_off.ring_mask), sqes=ctypes.cast(sqes, ctypes.POINTER(io_uring.struct_io_uring_sqe)))

def u32ptr(val): return ctypes.cast(val, ctypes.POINTER(ctypes.c_uint32))
sqdesc = io_uring.struct_io_uring_sq(khead=u32ptr(sq_ptr+p.sq_off.head), ktail=u32ptr(sq_ptr+p.sq_off.tail), array=u32ptr(sq_ptr+p.sq_off.array),
kring_mask=u32ptr(sq_ptr+p.sq_off.ring_mask), sqes=ctypes.cast(sqes, ctypes.POINTER(io_uring.struct_io_uring_sqe)))
cqdesc = io_uring.struct_io_uring_cq(khead=u32ptr(cq_ptr+p.cq_off.head), ktail=u32ptr(cq_ptr+p.cq_off.tail),
kring_mask=u32ptr(sq_ptr+p.cq_off.ring_mask), cqes=ctypes.cast(cq_ptr+p.cq_off.cqes, ctypes.POINTER(io_uring.struct_io_uring_cqe)))

cqdesc = io_uring.struct_io_uring_cq(khead=u32ptr(cq_ptr+p.cq_off.head), ktail=u32ptr(cq_ptr+p.cq_off.tail),
kring_mask=u32ptr(sq_ptr+p.cq_off.ring_mask), cqes=ctypes.cast(cq_ptr+p.cq_off.cqes, ctypes.POINTER(io_uring.struct_io_uring_cqe)))

DiskDevice.io_uring = io_uring.struct_io_uring(ring_fd=fd, sq=sqdesc, cq=cqdesc) # type: ignore
DiskDevice.io_uring = io_uring.struct_io_uring(ring_fd=fd, sq=sqdesc, cq=cqdesc) # type: ignore
3 changes: 2 additions & 1 deletion tinygrad/runtime/ops_dsp.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations
from typing import Tuple, Any
import ctypes, os, mmap, tempfile, pathlib, array, functools, threading, contextlib
import ctypes, os, mmap, tempfile, pathlib, array, functools, threading, contextlib, sys
assert sys.platform != 'win32'
from tinygrad.device import BufferOptions, Compiled, Allocator
from tinygrad.helpers import from_mv, getenv, DEBUG, round_up, mv_address, to_mv, cpu_objdump
from tinygrad.runtime.ops_clang import ClangCompiler
Expand Down
3 changes: 2 additions & 1 deletion tinygrad/runtime/ops_nv.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations
import os, ctypes, contextlib, re, fcntl, functools, mmap, struct, array, decimal
import os, ctypes, contextlib, re, fcntl, functools, mmap, struct, array, decimal, sys
assert sys.platform != 'win32'
from typing import Tuple, List, Any, cast, Union, Dict, Type
from dataclasses import dataclass
from tinygrad.runtime.support.hcq import HCQCompiled, HCQAllocator, HCQBuffer, HWCommandQueue, HWComputeQueue, HWCopyQueue, hcq_command
Expand Down
3 changes: 2 additions & 1 deletion tinygrad/runtime/ops_qcom.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations
import os, ctypes, functools, mmap, struct, array, decimal, math
import os, ctypes, functools, mmap, struct, array, decimal, math, sys
assert sys.platform != 'win32'
from types import SimpleNamespace
from typing import Tuple, List, Any, cast
from tinygrad.device import BufferOptions
Expand Down

0 comments on commit de7b9d7

Please sign in to comment.