Skip to content

Commit

Permalink
code and comment fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
kmaehashi committed May 24, 2021
1 parent 8c2a628 commit 9600338
Showing 1 changed file with 31 additions and 27 deletions.
58 changes: 31 additions & 27 deletions cupy/cuda/stream.pyx
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
from cupy_backends.cuda.api cimport runtime
from cupy_backends.cuda cimport stream as stream_module

import os
import threading
import weakref

from cupy_backends.cuda.api cimport runtime
from cupy_backends.cuda cimport stream as backends_stream

from cupy import _util


Expand Down Expand Up @@ -59,7 +59,7 @@ cdef class _ThreadLocal:
cdef int device_id = stream.device_id
if device_id == -1:
device_id = runtime.getDevice()
stream_module.set_current_stream_ptr(ptr, device_id)
backends_stream.set_current_stream_ptr(ptr, device_id)
self.current_stream_ref[device_id] = weakref.ref(stream)

cdef get_current_stream(self, int device_id=-1):
Expand All @@ -69,11 +69,11 @@ cdef class _ThreadLocal:
return stream_ref()

cdef intptr_t get_current_stream_ptr(self):
    """Return the current stream pointer reported by the backend module.

    This simply forwards to ``backends_stream.get_current_stream_ptr()``.
    """
    # A second, unreachable ``return`` through the legacy ``stream_module``
    # alias used to follow this line; only one call is needed.
    return backends_stream.get_current_stream_ptr()


cdef get_default_stream():
    """Return the default stream singleton.

    Returns:
        ``Stream.ptds`` when the backend reports that the per-thread default
        stream (ptds) is enabled, otherwise ``Stream.null``.
    """
    # The duplicated, unreachable ``return`` through the legacy
    # ``stream_module`` alias has been removed.
    if backends_stream.is_ptds_enabled():
        return Stream.ptds
    return Stream.null


cdef intptr_t get_current_stream_ptr():
Expand Down Expand Up @@ -188,7 +188,7 @@ cdef int check_stream_device_match(int device_id) except? -1:
return device_id


class BaseStream(object):
class _BaseStream:

"""CUDA stream.
Expand All @@ -199,12 +199,14 @@ class BaseStream(object):
"""

null = None
def __init__(self, ptr, device_id):
    """Store the raw stream pointer and its device ID on the instance.

    Args:
        ptr: Raw stream pointer value.
        device_id: ID of the device associated with the stream.
    """
    self.ptr, self.device_id = ptr, device_id

def __eq__(self, other):
    """Compare two streams by their underlying stream pointer.

    Args:
        other: Object to compare with. Anything exposing a ``ptr``
            attribute is compared by pointer value.

    Returns:
        The result of ``self.ptr == other.ptr``, or ``NotImplemented``
        when ``other`` has no ``ptr`` attribute so that Python can fall
        back to its default comparison instead of raising
        ``AttributeError``.
    """
    # This operator is needed because the ptr may be shared between
    # multiple stream instances (e.g., the `Stream.null` singleton and
    # `Stream(null=True)`, or `ExternalStream`s).
    try:
        other_ptr = other.ptr
    except AttributeError:
        return NotImplemented
    return self.ptr == other_ptr

def __enter__(self):
Expand Down Expand Up @@ -316,7 +318,7 @@ class BaseStream(object):
runtime.streamWaitEvent(self.ptr, event.ptr)


class Stream(BaseStream):
class Stream(_BaseStream):

"""CUDA stream.
Expand Down Expand Up @@ -346,26 +348,29 @@ class Stream(BaseStream):
"""

null = None
ptds = None

def __init__(self, null=False, non_blocking=False, ptds=False):
    """Create a new stream or wrap one of the built-in default streams.

    The flags are checked in the order ``null``, ``ptds``,
    ``non_blocking``; the first one that is true decides the stream kind.

    Args:
        null (bool): If ``True``, use the null stream (pointer ``0``).
        non_blocking (bool): If ``True``, create the stream with the
            ``runtime.streamNonBlocking`` flag.
        ptds (bool): If ``True``, use ``runtime.streamPerThread``
            (per-thread default stream).

    Raises:
        ValueError: If ``ptds`` is requested in a HIP environment.
    """
    if null:
        # TODO(pentschev): move to streamLegacy. This wasn't possible
        # because of a NCCL bug that should be fixed in the version
        # following 2.8.3-1.
        ptr = 0
        device_id = -1
    elif ptds:
        if runtime._is_hip_environment:
            raise ValueError('HIP does not support per-thread '
                             'default stream (ptds)')
        ptr = runtime.streamPerThread
        device_id = -1
    elif non_blocking:
        # Create the stream exactly once; creating it twice would leak
        # the first handle.
        ptr = runtime.streamCreateWithFlags(runtime.streamNonBlocking)
        device_id = runtime.getDevice()
    else:
        ptr = runtime.streamCreate()
        device_id = runtime.getDevice()
    # Pass ``device_id`` (not the undefined name ``device``) to the base
    # class, which stores both values on the instance.
    super().__init__(ptr, device_id)

def __del__(self, is_shutting_down=_util.is_shutting_down):
cdef intptr_t current_ptr
Expand All @@ -385,9 +390,9 @@ class Stream(BaseStream):
# because the memory would still be used in kernels executed in GPU.


class ExternalStream(BaseStream):
class ExternalStream(_BaseStream):

"""CUDA stream.
"""CUDA stream not managed by CuPy.
This class allows external streams to be used in CuPy by providing the
stream pointer obtained from the CUDA runtime call.
Expand All @@ -411,14 +416,13 @@ class ExternalStream(BaseStream):
"""

def __init__(self, ptr, device_id=-1):
    """Wrap a stream pointer that was created outside of CuPy.

    Args:
        ptr: Stream pointer obtained from the CUDA runtime.
        device_id (int): ID of the device the stream belongs to.
            Defaults to ``-1``.
            NOTE(review): ``-1`` appears to mean "resolve the device
            later via ``runtime.getDevice()``" - confirm against the
            code that makes a stream current.
    """
    # It is in theory unsafe to just call runtime.getDevice() here, as the
    # stream pointer could come from a different device (although
    # unlikely). While we could use driver API combos cuStreamGetCtx ->
    # cuCtxSetCurrent -> cuCtxGetDevice -> ... to retrieve the device ID
    # associated with the stream, it is way too complicated and does not
    # work with HIP. Let us keep this as thin as possible.
    #
    # The base class stores both attributes, so they are not assigned
    # separately here.
    super().__init__(ptr, device_id)


Stream.null = Stream(null=True)
Expand Down

0 comments on commit 9600338

Please sign in to comment.