Skip to content

Commit

Permalink
macro fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
monofuel committed Jul 6, 2024
1 parent 105d5f7 commit f91f363
Show file tree
Hide file tree
Showing 3 changed files with 71 additions and 40 deletions.
54 changes: 34 additions & 20 deletions src/cuda.nim
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# CUDA runtime C++ FFI
import std/strformat
import std/strformat, macros

type
size_t* = uint64
Expand Down Expand Up @@ -105,22 +105,36 @@ let
gridDim* {.importc, inject, header: "cuda_runtime.h".}: GridDim
threadIdx* {.importc, inject, header: "cuda_runtime.h".}: ThreadIdx

template hippoGlobal*(body: untyped) =
{.push stackTrace: off, checks: off, exportc, codegenDecl: "__global__ $# $#$#".}
body
{.pop}

template hippoDevice*(body: typed) =
{.push stackTrace: off, checks: off, exportc, codegenDecl: "__device__ $# $#$#".}
body
{.pop}

template hippoHost*(body: typed) =
{.push stackTrace: off, checks: off, exportc, codegenDecl: "__host__ $# $#$#".}
body
{.pop}

template hippoShared*(body: typed) =
{.push stackTrace: off, checks: off, exportc, codegenDecl: "__shared__ $# $#$#".}
body
{.pop}
macro hippoGlobal*(fn: untyped): untyped =
let globalPragma: NimNode = quote:
{. exportc, codegenDecl: "__global__ $# $#$#".}

fn.addPragma(globalPragma[0])
fn.addPragma(globalPragma[1])
quote do:
{.push stackTrace: off, checks: off.}
`fn`
{.pop.}

macro hippoDevice*(fn: untyped): untyped =
let globalPragma: NimNode = quote:
{. exportc, codegenDecl: "__device__ $# $#$#".}

fn.addPragma(globalPragma[0])
fn.addPragma(globalPragma[1])
quote do:
{.push stackTrace: off, checks: off.}
`fn`
{.pop.}


macro hippoHost*(fn: untyped): untyped =
let globalPragma: NimNode = quote:
{. exportc, codegenDecl: "__host__ $# $#$#".}

fn.addPragma(globalPragma[0])
fn.addPragma(globalPragma[1])
quote do:
{.push stackTrace: off, checks: off.}
`fn`
{.pop.}
54 changes: 34 additions & 20 deletions src/hip.nim
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# HIP runtime C++ FFI
import std/strformat
import std/strformat, macros

type
size_t* = uint64
Expand Down Expand Up @@ -101,22 +101,36 @@ let
gridDim* {.importc, inject, header: "hip/hip_runtime.h".}: GridDim
threadIdx* {.importc, inject, header: "hip/hip_runtime.h".}: ThreadIdx

template hippoGlobal*(body: untyped) =
{.push stackTrace: off, checks: off, exportc, codegenDecl: "__global__ $# $#$#".}
body
{.pop}

template hippoDevice*(body: typed) =
{.push stackTrace: off, checks: off, exportc, codegenDecl: "__device__ $# $#$#".}
body
{.pop}

template hippoHost*(body: typed) =
{.push stackTrace: off, checks: off, exportc, codegenDecl: "__host__ $# $#$#".}
body
{.pop}

template hippoShared*(body: typed) =
{.push stackTrace: off, checks: off, exportc, codegenDecl: "__shared__ $# $#$#".}
body
{.pop}
macro hippoGlobal*(fn: untyped): untyped =
let globalPragma: NimNode = quote:
{. exportc, codegenDecl: "__global__ $# $#$#".}

fn.addPragma(globalPragma[0])
fn.addPragma(globalPragma[1])
quote do:
{.push stackTrace: off, checks: off.}
`fn`
{.pop.}

macro hippoDevice*(fn: untyped): untyped =
let globalPragma: NimNode = quote:
{. exportc, codegenDecl: "__device__ $# $#$#".}

fn.addPragma(globalPragma[0])
fn.addPragma(globalPragma[1])
quote do:
{.push stackTrace: off, checks: off.}
`fn`
{.pop.}


macro hippoHost*(fn: untyped): untyped =
let globalPragma: NimNode = quote:
{. exportc, codegenDecl: "__host__ $# $#$#".}

fn.addPragma(globalPragma[0])
fn.addPragma(globalPragma[1])
quote do:
{.push stackTrace: off, checks: off.}
`fn`
{.pop.}
3 changes: 3 additions & 0 deletions tests/hip/dot.nim
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@ const
ThreadsPerBlock: int = 256
BlocksPerGrid: int = min(32, ((N + ThreadsPerBlock - 1) div ThreadsPerBlock))

# TODO improve this
{.pragma: hippoShared, exportc, codegenDecl: "__shared__ $# $#".}

proc dot(a, b, c: ptr[cfloat]){.hippoGlobal.} =

var cache {.hippoShared.}: ptr[cfloat]
Expand Down

0 comments on commit f91f363

Please sign in to comment.