Skip to content

Commit

Permalink
The interface file (uncommented)
Browse files Browse the repository at this point in the history
  • Loading branch information
lukstafi committed Jul 20, 2024
1 parent f36319e commit 072a063
Show file tree
Hide file tree
Showing 3 changed files with 395 additions and 7 deletions.
2 changes: 1 addition & 1 deletion CHANGES.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
### Fixed

- A major bug, exacerbated by the asynchronous functionaliy of v0.3 -- functions performing asynchronous calls should keep the call arguments alive; the user should only forget (or free) the arguments after the calls complete (e.g. after synchronizing a stream).
- Only `launch_kernel` needed fixing as I don't think other functions allocate passed arguments.
- Only `launch_kernel` needed fixing as I don't think other async functions allocate passed arguments.
- We hanlde this internally so no API change!

## [0.3.0] 2024-07-05
Expand Down
16 changes: 10 additions & 6 deletions cudajit.ml
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ let compile_to_ptx ~cu_src ~name ~options ~with_debug =
{ log; ptx; ptx_length = count - 1 }

let string_from_ptx prog = Ctypes.string_from_ptr prog.ptx ~length:prog.ptx_length
let compilation_log prog = prog.log

let check message status =
if status <> CUDA_SUCCESS then raise @@ Error { status = Cuda_error status; message }
Expand Down Expand Up @@ -520,15 +521,16 @@ let memset_d32 (Deviceptr dev) v ~length =
check "cu_memset_d32" @@ Cuda.cu_memset_d32 dev v @@ Unsigned.Size_t.of_int length

let memset_d8_async (Deviceptr dev) v ~length stream =
check "cu_memset_d8_async" @@ Cuda.cu_memset_d8_async dev v (Unsigned.Size_t.of_int length) stream
check "cu_memset_d8_async"
@@ Cuda.cu_memset_d8_async dev v (Unsigned.Size_t.of_int length) stream.stream

let memset_d16_async (Deviceptr dev) v ~length stream =
check "cu_memset_d16_async"
@@ Cuda.cu_memset_d16_async dev v (Unsigned.Size_t.of_int length) stream
@@ Cuda.cu_memset_d16_async dev v (Unsigned.Size_t.of_int length) stream.stream

let memset_d32_async (Deviceptr dev) v ~length stream =
check "cu_memset_d32_async"
@@ Cuda.cu_memset_d32_async dev v (Unsigned.Size_t.of_int length) stream
@@ Cuda.cu_memset_d32_async dev v (Unsigned.Size_t.of_int length) stream.stream

let module_get_global module_ ~name =
let open Ctypes in
Expand Down Expand Up @@ -1394,7 +1396,7 @@ let uint_of_attach_mem f =

let stream_attach_mem_async stream device length flag =
check "cu_stream_attach_mem_async"
@@ Cuda.cu_stream_attach_mem_async stream device (Unsigned.Size_t.of_int length)
@@ Cuda.cu_stream_attach_mem_async stream.stream device (Unsigned.Size_t.of_int length)
@@ uint_of_attach_mem flag

let uint_of_cu_stream_flags ~non_blocking =
Expand All @@ -1419,13 +1421,13 @@ let stream_destroy stream =
let stream_get_context stream =
let open Ctypes in
let ctx = allocate_n cu_context ~count:1 in
check "cu_stream_get_ctx" @@ Cuda.cu_stream_get_ctx stream ctx;
check "cu_stream_get_ctx" @@ Cuda.cu_stream_get_ctx stream.stream ctx;
!@ctx

let stream_get_id stream =
let open Ctypes in
let id = allocate uint64_t Unsigned.UInt64.zero in
check "cu_stream_get_id" @@ Cuda.cu_stream_get_id stream id;
check "cu_stream_get_id" @@ Cuda.cu_stream_get_id stream.stream id;
!@id

let stream_is_ready stream =
Expand All @@ -1445,3 +1447,5 @@ type func = cu_function
type module_ = cu_module
type limit = cu_limit
type device = cu_device
type nonrec nvrtc_result = Nvrtc_ffi.Bindings_types.nvrtc_result
type cuda_result = Cuda_ffi.Bindings_types.cu_result
Loading

0 comments on commit 072a063

Please sign in to comment.