Skip to content

Commit

Permalink
Merge branch 'main' into pr/6492
Browse files Browse the repository at this point in the history
  • Loading branch information
steven-johnson committed Mar 2, 2023
2 parents 951efb1 + 122b5b6 commit 0673e67
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 32 deletions.
3 changes: 3 additions & 0 deletions apps/hannk/interpreter/tensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,9 @@ class Tensor {
// Reports the host-dirty flag of the underlying buffer_ by forwarding to
// buffer_.host_dirty(); the result is returned unchanged.
bool host_dirty() const {
    const bool dirty = buffer_.host_dirty();
    return dirty;
}
// Forwards to buffer_.device_sync(ctx) and hands back its int result as-is.
// `ctx` defaults to nullptr and is passed through untouched.
int device_sync(void *ctx = nullptr) {
    const int status = buffer_.device_sync(ctx);
    return status;
}

void resize_dynamic(const Box &new_shape);

Expand Down
72 changes: 40 additions & 32 deletions src/runtime/gpu_context_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,8 @@ class GPUCompilationCache {
// One slot of the `compilations` table owned by GPUCompilationCache.
// NOTE(review): this listing comes from a diff view, so it shows both the
// 32-bit members plus explicit constructor and the uintptr_t members that
// appear after them; kernel_id and use_count are therefore declared twice
// here. Confirm against the real header which set is live.
struct CachedCompilation {
// Context associated with this entry.
ContextT context{};
// Module state associated with this entry.
ModuleStateT module_state{};
// 32-bit id / use counter.
uint32_t kernel_id{};
uint32_t use_count{0};

// Member-wise constructor for the 32-bit layout.
CachedCompilation(ContextT context, ModuleStateT module_state,
uint32_t kernel_id, uint32_t use_count)
: context(context), module_state(module_state),
kernel_id(kernel_id), use_count(use_count) {
}
// Pointer-width id. The enclosing class reserves 0 (kInvalidId) and
// 1 (kDeletedId); ids handed out via unique_id start at 2.
uintptr_t kernel_id{0};
// Pointer-width use counter; presumably adjusted through the `increment`
// argument of find_internal() -- TODO confirm.
uintptr_t use_count{0};
};

halide_mutex mutex;
Expand All @@ -27,17 +21,16 @@ class GPUCompilationCache {
CachedCompilation *compilations{nullptr};
int count{0};

static constexpr uint32_t kInvalidId{0};
static constexpr uint32_t kDeletedId{1};
static constexpr uintptr_t kInvalidId{0};
static constexpr uintptr_t kDeletedId{1};

uint32_t unique_id{2}; // zero is an invalid id
uintptr_t unique_id{2}; // zero is an invalid id

public:
static ALWAYS_INLINE uintptr_t kernel_hash(ContextT context, uint32_t id, uint32_t bits) {
static ALWAYS_INLINE uintptr_t kernel_hash(ContextT context, uintptr_t id, int bits) {
uintptr_t addr = (uintptr_t)context + id;
// Fibonacci hashing. The golden ratio is 1.9E3779B97F4A7C15F39...
// in hexadecimal.
if (sizeof(uintptr_t) >= 8) {
if constexpr (sizeof(uintptr_t) >= 8) {
return (addr * (uintptr_t)0x9E3779B97F4A7C15) >> (64 - bits);
} else {
return (addr * (uintptr_t)0x9E3779B9) >> (32 - bits);
Expand Down Expand Up @@ -70,7 +63,7 @@ class GPUCompilationCache {
return false;
}

HALIDE_MUST_USE_RESULT bool find_internal(ContextT context, uint32_t id,
HALIDE_MUST_USE_RESULT bool find_internal(ContextT context, uintptr_t id,
ModuleStateT *&module_state, int increment) {
if (log2_compilations_size == 0) {
return false;
Expand All @@ -94,17 +87,6 @@ class GPUCompilationCache {
return false;
}

// Look up the module state cached for (context, state_ptr).
//
// context      - context the compilation belongs to.
// state_ptr    - opaque per-kernel value whose bits are used as the id.
// module_state - receives a copy of the cached module state on success;
//                left untouched on failure.
//
// Returns true when find_internal() locates a matching entry, false otherwise.
// `mutex` is held (via ScopedMutexLock) for the whole call.
HALIDE_MUST_USE_RESULT bool lookup(ContextT context, void *state_ptr, ModuleStateT &module_state) {
    ScopedMutexLock lock_guard(&mutex);
    // Keep the full pointer-width value as the id. Narrowing it to 32 bits
    // would disagree with the uintptr_t ids stored in the table and accepted
    // by find_internal()/kernel_hash().
    uintptr_t id = (uintptr_t)state_ptr;
    ModuleStateT *mod_ptr = nullptr;
    // An increment of 0 is passed: this is a pure query.
    if (find_internal(context, id, mod_ptr, 0)) {
        module_state = *mod_ptr;
        return true;
    }
    return false;
}

HALIDE_MUST_USE_RESULT bool resize_table(int size_bits) {
if (size_bits != log2_compilations_size) {
int new_size = (1 << size_bits);
Expand Down Expand Up @@ -135,7 +117,7 @@ class GPUCompilationCache {
}

template<typename FreeModuleT>
void release_context(void *user_context, bool all, ContextT context, FreeModuleT &f) {
void release_context_already_locked(void *user_context, bool all, ContextT context, FreeModuleT &f) {
if (count == 0) {
return;
}
Expand All @@ -155,18 +137,38 @@ class GPUCompilationCache {
}
}

public:
// Fetch the module state cached for (context, state_ptr).
// On a hit, copies the cached state into `module_state` and returns true;
// on a miss, returns false and leaves `module_state` untouched.
HALIDE_MUST_USE_RESULT bool lookup(ContextT context, void *state_ptr, ModuleStateT &module_state) {
    // Hold `mutex` for the remainder of the function.
    ScopedMutexLock lock_guard(&mutex);

    // The pointer's bits are used directly as the id.
    const uintptr_t key = (uintptr_t)state_ptr;
    ModuleStateT *found;
    // find_internal() is called with an increment of 0.
    if (!find_internal(context, key, found, 0)) {
        return false;
    }
    module_state = *found;
    return true;
}

// Locking entry point: acquires `mutex` and then delegates, with the same
// arguments, to release_context_already_locked().
template<typename FreeModuleT>
void release_context(void *user_context, bool all, ContextT context, FreeModuleT &f) {
    ScopedMutexLock scoped_hold(&mutex);
    release_context_already_locked(user_context, all, context, f);
}

// Releases the cache entries for a single `context` (passes all == false),
// using `f` as the module-freeing callback. `mutex` is held for the call.
//
// NOTE(review): this listing comes from a diff view and contains both the
// old call to release_context() and its replacement,
// release_context_already_locked(). release_context() acquires `mutex`
// itself, so calling it while lock_guard is held looks wrong; presumably
// only the second call exists in the real file -- confirm.
template<typename FreeModuleT>
void delete_context(void *user_context, ContextT context, FreeModuleT &f) {
ScopedMutexLock lock_guard(&mutex);

release_context(user_context, false, context, f);
release_context_already_locked(user_context, false, context, f);
}

template<typename FreeModuleT>
void release_all(void *user_context, FreeModuleT &f) {
ScopedMutexLock lock_guard(&mutex);

release_context(user_context, true, nullptr, f);
release_context_already_locked(user_context, true, nullptr, f);
// Some items may have been in use, so can't free.
if (count == 0) {
free(compilations);
Expand All @@ -176,15 +178,19 @@ class GPUCompilationCache {
}

template<typename CompileModuleT, typename... Args>
HALIDE_MUST_USE_RESULT bool kernel_state_setup(void *user_context, void **state_ptr,
HALIDE_MUST_USE_RESULT bool kernel_state_setup(void *user_context, void **state_ptr_ptr,
ContextT context, ModuleStateT &result,
CompileModuleT f,
Args... args) {
ScopedMutexLock lock_guard(&mutex);

uint32_t *id_ptr = (uint32_t *)state_ptr;
uintptr_t *id_ptr = (uintptr_t *)state_ptr_ptr;
if (*id_ptr == 0) {
*id_ptr = unique_id++;
if (unique_id == (uintptr_t)-1) {
// Sorry, out of ids
return false;
}
}

ModuleStateT *mod;
Expand All @@ -210,8 +216,10 @@ class GPUCompilationCache {
}

void release_hold(void *user_context, ContextT context, void *state_ptr) {
ScopedMutexLock lock_guard(&mutex);

ModuleStateT *mod;
uint32_t id = (uint32_t)(uintptr_t)state_ptr;
uintptr_t id = (uintptr_t)state_ptr;
bool result = find_internal(context, id, mod, -1);
halide_debug_assert(user_context, result); // Value must be in cache to be released
(void)result;
Expand Down
18 changes: 18 additions & 0 deletions src/runtime/printer.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,17 @@
#ifndef HALIDE_RUNTIME_PRINTER_H
#define HALIDE_RUNTIME_PRINTER_H

// This is useful for debugging threading issues in the Halide runtime:
// prefix all `debug()` statements with the thread id that did the logging.
// Left here (but disabled) for future reference.
//
// Define HALIDE_RUNTIME_PRINTER_LOG_THREADID to a nonzero value before this
// header is processed to enable it; it defaults to 0 (off).
#ifndef HALIDE_RUNTIME_PRINTER_LOG_THREADID
#define HALIDE_RUNTIME_PRINTER_LOG_THREADID 0
#endif

#if HALIDE_RUNTIME_PRINTER_LOG_THREADID
// Declared by hand instead of including a system header.
// NOTE(review): pthread_threadid_np is a non-portable pthread extension
// (Darwin); the first parameter is spelled `long` here rather than
// pthread_t -- confirm this matches the target platform's ABI before enabling.
extern "C" int pthread_threadid_np(long thread, uint64_t *thread_id);
#endif

namespace Halide {
namespace Runtime {
namespace Internal {
Expand Down Expand Up @@ -51,6 +63,12 @@ class Printer {
// Pointers equal ensures no writes to buffer via formatting code
end = dst;
}

#if HALIDE_RUNTIME_PRINTER_LOG_THREADID
uint64_t tid;
pthread_threadid_np(0, &tid);
*this << "(TID:" << tid << ")";
#endif
}

// Not movable, not copyable
Expand Down

0 comments on commit 0673e67

Please sign in to comment.