[WIP] dot product #2

Merged 3 commits on Jul 6, 2024
2 changes: 1 addition & 1 deletion .github/workflows/build.yml
@@ -12,7 +12,7 @@ jobs:
      - uses: actions/checkout@v3
        with:
          submodules: recursive
      - uses: jiro4989/setup-nim-action@v1
      - uses: jiro4989/setup-nim-action@v2
        with:
          repo-token: ${{ secrets.GITHUB_TOKEN }}
      - name: Install dependencies
51 changes: 36 additions & 15 deletions src/cuda.nim
@@ -1,5 +1,5 @@
# CUDA runtime C++ FFI
import std/strformat
import std/strformat, macros

type
  size_t* = uint64
@@ -59,6 +59,8 @@ proc cudaLaunchKernel*(function_address: pointer; numBlocks: Dim3; dimBlocks: Di
# args: ptr pointer; sharedMemBytes: csize_t; stream: cudaStream_t): cint {.
# importcpp: "cudaLaunchKernel(@)", header: "cuda_runtime.h".}
proc cudaDeviceSynchronize*(): cudaError_t {.header: "cuda_runtime.h",importcpp: "cudaDeviceSynchronize(@)".}
proc cudaSyncthreads*() {.importcpp: "__syncthreads()", header: "cuda_runtime.h".}
proc hippoSyncthreads*() {.importcpp: "__syncthreads()", header: "cuda_runtime.h".}

proc cudaLaunchKernelGGL*(
function_address: proc;
@@ -103,17 +105,36 @@ let
  gridDim* {.importc, inject, header: "cuda_runtime.h".}: GridDim
  threadIdx* {.importc, inject, header: "cuda_runtime.h".}: ThreadIdx

template hippoGlobal*(body: untyped) =
  {.push stackTrace: off, checks: off, exportc, codegenDecl: "__global__ $# $#$#".}
  body
  {.pop}

template hippoDevice*(body: typed) =
  {.push stackTrace: off, checks: off, exportc, codegenDecl: "__device__ $# $#$#".}
  body
  {.pop}

template hippoHost*(body: typed) =
  {.push stackTrace: off, checks: off, exportc, codegenDecl: "__host__ $# $#$#".}
  body
  {.pop}
macro hippoGlobal*(fn: untyped): untyped =
  let globalPragma: NimNode = quote:
    {. exportc, codegenDecl: "__global__ $# $#$#".}

  fn.addPragma(globalPragma[0])
  fn.addPragma(globalPragma[1])
  quote do:
    {.push stackTrace: off, checks: off.}
    `fn`
    {.pop.}

macro hippoDevice*(fn: untyped): untyped =
  let globalPragma: NimNode = quote:
    {. exportc, codegenDecl: "__device__ $# $#$#".}

  fn.addPragma(globalPragma[0])
  fn.addPragma(globalPragma[1])
  quote do:
    {.push stackTrace: off, checks: off.}
    `fn`
    {.pop.}


macro hippoHost*(fn: untyped): untyped =
  let globalPragma: NimNode = quote:
    {. exportc, codegenDecl: "__host__ $# $#$#".}

  fn.addPragma(globalPragma[0])
  fn.addPragma(globalPragma[1])
  quote do:
    {.push stackTrace: off, checks: off.}
    `fn`
    {.pop.}
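
With the switch from templates to macros, a kernel is written as an ordinary Nim proc carrying the hippoGlobal pragma instead of a block wrapped by a template call; the macro attaches exportc and the __global__ codegenDecl to the proc and surrounds it with a push/pop that disables stack traces and runtime checks. A minimal sketch of the new form (the kernel below is illustrative only, not part of this PR):

import hippo

# doubles every element; one thread per element
proc scaleKernel(data: ptr cfloat) {.hippoGlobal.} =
  let arr = cast[ptr UncheckedArray[cfloat]](data)
  let tid = threadIdx.x + blockIdx.x * blockDim.x
  arr[tid] = arr[tid] * 2.0

The hip.nim changes below are identical, so the same proc form compiles against either backend.
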
51 changes: 36 additions & 15 deletions src/hip.nim
@@ -1,5 +1,5 @@
# HIP runtime C++ FFI
import std/strformat
import std/strformat, macros

type
  size_t* = uint64
@@ -56,6 +56,8 @@ proc hipLaunchKernel*(function_address: pointer; numBlocks: Dim3; dimBlocks: Dim
# args: ptr pointer; sharedMemBytes: csize_t; stream: hipStream_t): cint {.
# importcpp: "hipLaunchKernel(@)", header: "hip/hip_runtime.h".}
proc hipDeviceSynchronize*(): hipError_t {.header: "hip/hip_runtime.h",importcpp: "hipDeviceSynchronize(@)".}
proc hipSyncthreads*() {.importcpp: "__syncthreads()", header: "hip/hip_runtime.h".}
proc hippoSyncthreads*() {.importcpp: "__syncthreads()", header: "hip/hip_runtime.h".}

proc hipLaunchKernelGGL*(
function_address: proc;
@@ -99,17 +101,36 @@ let
  gridDim* {.importc, inject, header: "hip/hip_runtime.h".}: GridDim
  threadIdx* {.importc, inject, header: "hip/hip_runtime.h".}: ThreadIdx

template hippoGlobal*(body: untyped) =
  {.push stackTrace: off, checks: off, exportc, codegenDecl: "__global__ $# $#$#".}
  body
  {.pop}

template hippoDevice*(body: typed) =
  {.push stackTrace: off, checks: off, exportc, codegenDecl: "__device__ $# $#$#".}
  body
  {.pop}

template hippoHost*(body: typed) =
  {.push stackTrace: off, checks: off, exportc, codegenDecl: "__host__ $# $#$#".}
  body
  {.pop}
macro hippoGlobal*(fn: untyped): untyped =
  let globalPragma: NimNode = quote:
    {. exportc, codegenDecl: "__global__ $# $#$#".}

  fn.addPragma(globalPragma[0])
  fn.addPragma(globalPragma[1])
  quote do:
    {.push stackTrace: off, checks: off.}
    `fn`
    {.pop.}

macro hippoDevice*(fn: untyped): untyped =
  let globalPragma: NimNode = quote:
    {. exportc, codegenDecl: "__device__ $# $#$#".}

  fn.addPragma(globalPragma[0])
  fn.addPragma(globalPragma[1])
  quote do:
    {.push stackTrace: off, checks: off.}
    `fn`
    {.pop.}


macro hippoHost*(fn: untyped): untyped =
  let globalPragma: NimNode = quote:
    {. exportc, codegenDecl: "__host__ $# $#$#".}

  fn.addPragma(globalPragma[0])
  fn.addPragma(globalPragma[1])
  quote do:
    {.push stackTrace: off, checks: off.}
    `fn`
    {.pop.}
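
As on the CUDA side, the macro rewrite attaches the pragmas to the proc definition itself. Applied to a proc, hippoGlobal expands to roughly the following (illustrative; the exact tree can be printed with expandMacros from std/macros). The quoted {. exportc, codegenDecl: ... .} statement is a pragma node with two children, which is why addPragma is called with globalPragma[0] and globalPragma[1]:

# before:
#   proc myKernel(x: ptr cfloat) {.hippoGlobal.} =
#     discard
# roughly after expansion:
{.push stackTrace: off, checks: off.}
proc myKernel(x: ptr cfloat) {.exportc, codegenDecl: "__global__ $# $#$#".} =
  discard
{.pop.}
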
2 changes: 1 addition & 1 deletion src/hippo.nim
@@ -83,7 +83,7 @@ proc launchKernel*(
  kernel: proc,
  gridDim: Dim3 = newDim3(1,1,1), # default to a grid of 1 block
  blockDim: Dim3 = newDim3(1,1,1), # default to 1 thread per block
  sharedMemBytes: uint32 = 0,
  sharedMemBytes: uint32 = 0, # TODO dynamic shared memory
  stream: HippoStream = nil,
  args: tuple
): HippoError =
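
The only change here is a comment noting that sharedMemBytes is not yet wired up to dynamic shared memory. For reference, a call with the defaults spelled out looks like the sketch below (the kernel name k and device pointer devBuf are placeholders; the dot-product test further down shows a real call):

handleError(launchKernel(
  k,                       # a proc annotated with {.hippoGlobal.}
  gridDim = newDim3(32),   # defaults to newDim3(1,1,1) if omitted
  blockDim = newDim3(256), # defaults to newDim3(1,1,1) if omitted
  args = (devBuf,)         # kernel arguments packed as a tuple
))
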
95 changes: 95 additions & 0 deletions tests/hip/dot.nim
@@ -0,0 +1,95 @@
import hippo

# GPU Dot product

const
  N = 33 * 1024
  ThreadsPerBlock: int = 256
  BlocksPerGrid: int = min(32, ((N + ThreadsPerBlock - 1) div ThreadsPerBlock))
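
# With N = 33 * 1024 = 33792 and 256 threads per block, the ceiling division
# (N + ThreadsPerBlock - 1) div ThreadsPerBlock evaluates to 132, so
# BlocksPerGrid = min(32, 132) = 32; the grid-stride loop in the kernel below
# covers the elements beyond the first 32 * 256 threads.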

# TODO improve this
{.pragma: hippoShared, exportc, codegenDecl: "__shared__ $# $#".}

proc dot(a, b, c: ptr[cfloat]) {.hippoGlobal.} =

  # per-block scratch space, one slot per thread
  # (equivalent to `__shared__ float cache[ThreadsPerBlock]` in the CUDA original;
  # noinit avoids an implicit zero-fill of shared memory by every thread)
  var cache {.hippoShared, noinit.}: array[ThreadsPerBlock, cfloat]

  # TODO figure out how to do this properly
  let aArray = cast[ptr UncheckedArray[cfloat]](a)
  let bArray = cast[ptr UncheckedArray[cfloat]](b)
  let cArray = cast[ptr UncheckedArray[cfloat]](c)
  let cacheArray = cast[ptr UncheckedArray[cfloat]](addr cache)

  let cacheIndex = threadIdx.x
  var tid = threadIdx.x + blockIdx.x * blockDim.x
  var temp: cfloat = 0
  while tid < N:
    temp += aArray[tid] * bArray[tid]
    tid += blockDim.x * gridDim.x

  # set the cache values
  cacheArray[cacheIndex] = temp

  # synchronize threads in this block
  hippoSyncthreads()

  # for reductions, ThreadsPerBlock must be a power of 2
  # because of the following code
  var i = blockDim.x div 2
  while i != 0:
    if cacheIndex < i:
      cacheArray[cacheIndex] += cacheArray[cacheIndex + i]
    hippoSyncthreads()
    i = i div 2

  if cacheIndex == 0:
    cArray[blockIdx.x] = cacheArray[0]


proc main() =
  # host buffers use cfloat so they match the kernel's element type,
  # and partial_c only needs one slot per block
  var a, b: array[N, cfloat]
  var partial_c: array[BlocksPerGrid, cfloat]
  var dev_a, dev_b, dev_partial_c: pointer

  # allocate gpu memory
  handleError(hipMalloc(addr dev_a, sizeof(cfloat)*N))
  handleError(hipMalloc(addr dev_b, sizeof(cfloat)*N))
  handleError(hipMalloc(addr dev_partial_c, BlocksPerGrid * sizeof(cfloat)))

  # fill in host memory with data
  for i in 0 ..< N:
    a[i] = i.cfloat
    b[i] = (i * 2).cfloat

  # copy data to device
  handleError(hipMemcpy(dev_a, addr a[0], sizeof(cfloat)*N, hipMemcpyHostToDevice))
  handleError(hipMemcpy(dev_b, addr b[0], sizeof(cfloat)*N, hipMemcpyHostToDevice))

  # launch kernel
  handleError(launchKernel(
    dot,
    gridDim = newDim3(BlocksPerGrid.uint32),
    blockDim = newDim3(ThreadsPerBlock.uint32),
    args = (dev_a, dev_b, dev_partial_c)
  ))

  # copy memory back from GPU to CPU
  handleError(hipMemcpy(addr partial_c[0], dev_partial_c, BlocksPerGrid * sizeof(cfloat), hipMemcpyDeviceToHost))

  # finish up on the CPU: sum the per-block partial results
  var c: cfloat = 0
  for i in 0 ..< BlocksPerGrid:
    c += partial_c[i]


  proc sum_squares(x: float): float =
    result = x * (x + 1) * (2 * x + 1) / 6

  echo "Does GPU value ", c.float, " = ", 2 * sum_squares(float(N - 1)), "?"

  handleError(hipFree(dev_a))
  handleError(hipFree(dev_b))
  handleError(hipFree(dev_partial_c))


when isMainModule:
  main()
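
The closing check uses the closed form for a sum of squares: with a[i] = i and b[i] = 2*i, the dot product is 2 * (0^2 + 1^2 + ... + (N-1)^2), which is exactly 2 * sum_squares(N - 1) with sum_squares(x) = x*(x+1)*(2x+1)/6. A CPU-only cross-check of that identity (illustrative, not part of the test):

const N = 33 * 1024

proc cpuDot(n: int): float =
  # reference dot product of a[i] = i and b[i] = 2 * i
  for i in 0 ..< n:
    result += float(i) * float(2 * i)

# 2 * sum_squares(N - 1) simplifies to (N - 1) * N * (2N - 1) / 3
echo "loop: ", cpuDot(N), "  closed form: ", float(N - 1) * float(N) * float(2 * N - 1) / 3.0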