Skip to content

Commit

Permalink
Interface cleanup, in progress: expose more types
Browse files Browse the repository at this point in the history
  • Loading branch information
lukstafi committed Jul 20, 2024
1 parent efa0709 commit 98f1a46
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 47 deletions.
9 changes: 6 additions & 3 deletions cudajit.ml
Original file line number Diff line number Diff line change
Expand Up @@ -319,6 +319,9 @@ type deviceptr =
| Deviceptr of Unsigned.uint64
(** A pointer to an array on a device. (Not a pointer to a device!) *)

let string_of_deviceptr (Deviceptr id) = Unsigned.UInt64.to_hexstring id
let sexp_of_deviceptr ptr = Sexplib0.Sexp.Atom (string_of_deviceptr ptr)

let mem_alloc ~size_in_bytes =
let open Ctypes in
let deviceptr = allocate_n cu_deviceptr ~count:1 in
Expand Down Expand Up @@ -1394,7 +1397,7 @@ let uint_of_attach_mem f =
| Mem_host -> Unsigned.UInt.of_int64 cu_mem_attach_host
| Mem_single_stream -> Unsigned.UInt.of_int64 cu_mem_attach_single

let stream_attach_mem_async stream device length flag =
let stream_attach_mem_async stream (Deviceptr device) length flag =
check "cu_stream_attach_mem_async"
@@ Cuda.cu_stream_attach_mem_async stream.stream device (Unsigned.Size_t.of_int length)
@@ uint_of_attach_mem flag
Expand Down Expand Up @@ -1447,5 +1450,5 @@ type func = cu_function
type module_ = cu_module
type limit = cu_limit
type device = cu_device
type nonrec nvrtc_result = Nvrtc_ffi.Bindings_types.nvrtc_result
type cuda_result = Cuda_ffi.Bindings_types.cu_result
type nonrec nvrtc_result = Nvrtc_ffi.Bindings_types.nvrtc_result [@@deriving sexp]
type cuda_result = Cuda_ffi.Bindings_types.cu_result [@@deriving sexp]
57 changes: 14 additions & 43 deletions cudajit.mli
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,13 @@ type func
type module_
type limit
type device
type nvrtc_result
type cuda_result
type error_code = Nvrtc_error of nvrtc_result | Cuda_error of cuda_result

val error_code_of_sexp : Sexplib0.Sexp.t -> error_code
val sexp_of_error_code : error_code -> Sexplib0.Sexp.t
type nvrtc_result [@@deriving sexp]
type cuda_result [@@deriving sexp]
type error_code = Nvrtc_error of nvrtc_result | Cuda_error of cuda_result [@@deriving sexp]

exception Error of { status : error_code; message : string }

type compile_to_ptx_result = { log : string option; ptx : char Ctypes.ptr; ptx_length : int }
type compile_to_ptx_result

val compile_to_ptx :
cu_src:string -> name:string -> options:string list -> with_debug:bool -> compile_to_ptx_result
Expand Down Expand Up @@ -66,27 +63,15 @@ type jit_option =
| JIT_GENERATE_LINE_INFO of bool
| JIT_CACHE_MODE of Cuda_ffi.Bindings_types.cu_jit_cache_mode
| JIT_POSITION_INDEPENDENT_CODE of bool
[@@deriving sexp]

val jit_option_of_sexp : Sexplib0.Sexp.t -> jit_option
val sexp_of_jit_option : jit_option -> Sexplib0.Sexp.t
val uint_of_cu_jit_target : Cuda_ffi.Bindings_types.cu_jit_target -> Unsigned.uint
val uint_of_cu_jit_fallback : Cuda_ffi.Bindings_types.cu_jit_fallback -> Unsigned.uint
val uint_of_cu_jit_cache_mode : Cuda_ffi.Bindings_types.cu_jit_cache_mode -> Unsigned.uint
val module_load_data_ex : compile_to_ptx_result -> jit_option list -> module_
val module_get_function : module_ -> name:string -> Cuda_ffi.Bindings_types.cu_function
val module_get_function : module_ -> name:string -> func

type deviceptr = Deviceptr of Unsigned.uint64
type deviceptr [@@deriving sexp_of]

val string_of_deviceptr : deviceptr -> string
val mem_alloc : size_in_bytes:int -> deviceptr

val memcpy_H_to_D_impl :
?host_offset:int ->
?length:int ->
dst:'a ->
src:('b, 'c, 'd) Bigarray.Genarray.t ->
(dst:'a -> src:unit Ctypes_static.ptr -> size_in_bytes:int -> 'e) ->
'e

val memcpy_H_to_D_unsafe : dst:deviceptr -> src:unit Ctypes.ptr -> size_in_bytes:int -> unit

val memcpy_H_to_D :
Expand Down Expand Up @@ -122,7 +107,7 @@ type kernel_param =
val no_stream : stream

val launch_kernel :
Cuda_ffi.Bindings_types.cu_function ->
func ->
grid_dim_x:int ->
?grid_dim_y:int ->
?grid_dim_z:int ->
Expand All @@ -135,15 +120,6 @@ val launch_kernel :
unit

val ctx_synchronize : unit -> unit

val memcpy_D_to_H_impl :
?host_offset:int ->
?length:int ->
dst:('a, 'b, 'c) Bigarray.Genarray.t ->
src:'d ->
(dst:unit Ctypes_static.ptr -> src:'d -> size_in_bytes:int -> 'e) ->
'e

val memcpy_D_to_H_unsafe : dst:unit Ctypes.ptr -> src:deviceptr -> size_in_bytes:int -> unit

val memcpy_D_to_H :
Expand Down Expand Up @@ -356,20 +332,15 @@ type device_attributes = {
unified_function_pointers : bool;
multicast_supported : bool;
}
[@@deriving sexp]

val device_attributes_of_sexp : Sexplib0.Sexp.t -> device_attributes
val sexp_of_device_attributes : device_attributes -> Sexplib0.Sexp.t
val device_get_attributes : device -> device_attributes
val ctx_set_limit : Cuda_ffi.Bindings_types.cu_limit -> int -> unit
val ctx_get_limit : Cuda_ffi.Bindings_types.cu_limit -> Unsigned.size_t
val ctx_set_limit : limit -> int -> unit
val ctx_get_limit : limit -> Unsigned.size_t

type attach_mem = Mem_global | Mem_host | Mem_single_stream
type attach_mem = Mem_global | Mem_host | Mem_single_stream [@@deriving sexp]

val attach_mem_of_sexp : Sexplib0.Sexp.t -> attach_mem
val sexp_of_attach_mem : attach_mem -> Sexplib0.Sexp.t
val uint_of_attach_mem : attach_mem -> Unsigned.uint
val stream_attach_mem_async : stream -> Unsigned.uint64 -> int -> attach_mem -> unit
val uint_of_cu_stream_flags : non_blocking:bool -> Unsigned.uint
val stream_attach_mem_async : stream -> deviceptr -> int -> attach_mem -> unit
val stream_create : ?non_blocking:bool -> ?lower_priority:int -> unit -> stream
val stream_destroy : stream -> unit
val stream_get_context : stream -> context
Expand Down
2 changes: 1 addition & 1 deletion test_no_device/saxpy_ptx.ml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ let%expect_test "SAXPY compilation" =
Cudajit.compile_to_ptx ~cu_src:kernel ~name:"saxpy" ~options:[ "--use_fast_math" ]
~with_debug:true
in
(match prog.log with None -> () | Some log -> Format.printf "\nCUDA Compile log: %s\n%!" log);
(match Cudajit.compilation_log prog with None -> () | Some log -> Format.printf "\nCUDA Compile log: %s\n%!" log);
[%expect {| CUDA Compile log: |}];
Format.printf "PTX: %s%!"
@@ Str.global_replace
Expand Down

0 comments on commit 98f1a46

Please sign in to comment.