Device memory spill support #35
Diff: __init__.py

@@ -1 +1,2 @@
 from .local_cuda_cluster import LocalCUDACluster
+from . import config
Diff: config.py (new file)

@@ -0,0 +1,13 @@
import yaml
import os

import dask

config = dask.config.config


fn = os.path.join(os.path.dirname(__file__), 'dask_cuda.yaml')
with open(fn) as f:
    dask_cuda_defaults = yaml.load(f)

dask.config.update_defaults(dask_cuda_defaults)
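
For readers unfamiliar with this registration mechanism, the short sketch below (not part of the diff) shows how the defaults loaded above become visible through dask.config once the package has been imported; the 0.7 value assumes the unmodified dask_cuda.yaml shown later in this PR.

# Hedged sketch: assumes the dask_cuda package is installed and that importing
# it triggers the `from . import config` above, which registers the defaults.
import dask
import dask_cuda  # importing the package runs config.py and update_defaults()

spill_fraction = dask.config.get('dask-cuda.worker.device-memory.spill')
print(spill_fraction)  # 0.7 with the shipped dask_cuda.yaml defaults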

Diff: cuda_worker.py (new file)

@@ -0,0 +1,256 @@
from tornado import gen
from numba import cuda

import dask
from distributed import Worker
from distributed.worker import logger
from distributed.compatibility import unicode
from distributed.utils import format_bytes, ignoring, parse_bytes, PeriodicCallback

from .device_host_file import DeviceHostFile


def get_device_total_memory():
    """ Return total memory of CUDA device from current context """
    return cuda.current_context().get_memory_info()[1]  # (free, total)


def get_device_used_memory():
    """ Return used memory of CUDA device from current context """
    memory_info = cuda.current_context().get_memory_info()  # (free, total)
    return memory_info[1] - memory_info[0]


def parse_device_memory_limit(memory_limit, ncores):
    """ Parse device memory limit input """
    if memory_limit is None or memory_limit == 0 or memory_limit == 'auto':
        memory_limit = int(get_device_total_memory())
    with ignoring(ValueError, TypeError):
        x = float(memory_limit)
        if isinstance(x, float) and x <= 1:
            return int(x * get_device_total_memory())

    if isinstance(memory_limit, (unicode, str)):
        return parse_bytes(memory_limit)
    else:
        return int(memory_limit)


class CUDAWorker(Worker):
    """ CUDA Worker node in a Dask distributed cluster

    Parameters
    ----------
    device_memory_limit: int, float, string
        Number of bytes of CUDA device memory that this worker should use.
        Set to zero for no limit or 'auto' to use 100% of device memory.
        Use strings or numbers like 5GB or 5e9
    device_memory_target_fraction: float
        Fraction of CUDA device memory to try to stay beneath
    device_memory_spill_fraction: float
        Fraction of CUDA device memory at which we start spilling to host
    device_memory_pause_fraction: float
        Fraction of CUDA device memory at which we stop running new tasks

    Note: CUDAWorker is a subclass of distributed.Worker; only parameters
    specific to CUDAWorker are listed here. For a complete list of
    parameters, refer to that class.
    """

    def __init__(self, **kwargs):
        self.device_memory_limit = kwargs.pop('device_memory_limit')

        if 'device_memory_target_fraction' in kwargs:
            self.device_memory_target_fraction = kwargs.pop(
                'device_memory_target_fraction')
        else:
            self.device_memory_target_fraction = dask.config.get(
                'dask-cuda.worker.device-memory.target')
        if 'device_memory_spill_fraction' in kwargs:
            self.device_memory_spill_fraction = kwargs.pop(
                'device_memory_spill_fraction')
        else:
            self.device_memory_spill_fraction = dask.config.get(
                'dask-cuda.worker.device-memory.spill')
        if 'device_memory_pause_fraction' in kwargs:
            self.device_memory_pause_fraction = kwargs.pop(
                'device_memory_pause_fraction')
        else:
            self.device_memory_pause_fraction = dask.config.get(
                'dask-cuda.worker.device-memory.pause')

        super().__init__(**kwargs)

        self.device_memory_limit = parse_device_memory_limit(
            self.device_memory_limit, self.ncores)

        print('self.device_memory_target_fraction', self.device_memory_target_fraction)
        print('self.device_memory_spill_fraction', self.device_memory_spill_fraction)
        print('self.device_memory_pause_fraction', self.device_memory_pause_fraction)

        self.data = DeviceHostFile(device_memory_limit=self.device_memory_limit,
                                   memory_limit=self.memory_limit,
                                   local_dir=self.local_dir)

        self._paused = False
        self._device_paused = False

        if self.device_memory_limit:
            self._device_memory_monitoring = False
            pc = PeriodicCallback(
                self.device_memory_monitor,
                self.memory_monitor_interval * 1000,
                io_loop=self.io_loop
            )
            self.periodic_callbacks["device_memory"] = pc

    def _start(self, addr_on_port=0):
        super()._start(addr_on_port)
        if self.device_memory_limit:
            logger.info(' Device Memory: %26s',
                        format_bytes(self.device_memory_limit))
            logger.info('-' * 49)

    def _check_for_pause(self, fraction, pause_fraction, used_memory, memory_limit,
                         paused, free_func, worker_description):
        if pause_fraction and fraction > pause_fraction:
            # Try to free some memory while in paused state
            if free_func:
                free_func()
            if not self._paused:
                logger.warning("%s is at %d%% memory usage. Pausing worker. "
                               "Process memory: %s -- Worker memory limit: %s",
                               worker_description,
                               int(fraction * 100),
                               format_bytes(used_memory),
                               format_bytes(memory_limit))
            return True
        elif paused:
            logger.warning("Worker is at %d%% memory usage. Resuming worker. "
                           "Process memory: %s -- Worker memory limit: %s",
                           int(fraction * 100),
                           format_bytes(used_memory),
                           format_bytes(memory_limit))
            self.ensure_computing()
        return False

    @gen.coroutine
    def memory_monitor(self):
        """ Track this process's memory usage and act accordingly

        If we rise above (memory_spill_fraction * memory_limit) of
        memory use, start dumping data to disk. The default value for
        memory_spill_fraction is 0.7, defined via configuration
        'distributed.worker.memory.spill'.

        If we rise above (memory_pause_fraction * memory_limit) of
        memory use, stop execution of new tasks. The default value
        for memory_pause_fraction is 0.8, defined via configuration
        'distributed.worker.memory.pause'.
        """
        if self._memory_monitoring:
            return
        self._memory_monitoring = True
        total = 0

        proc = self.monitor.proc
        memory = proc.memory_info().rss
        frac = memory / self.memory_limit

        # Pause worker threads if memory use is above
        # (self.memory_pause_fraction * 100)%
        self._paused = self._check_for_pause(frac, self.memory_pause_fraction, memory,
                                             self.memory_limit, self._paused,
                                             self._throttled_gc.collect, 'Worker')
        self.paused = (self._paused or self._device_paused)

        # Dump data to disk if memory use is above
        # (self.memory_spill_fraction * 100)%
        if self.memory_spill_fraction and frac > self.memory_spill_fraction:
            target = self.memory_limit * self.memory_target_fraction
            count = 0
            need = memory - target
            while memory > target:
                if not self.data.host.fast:
                    logger.warning("Memory use is high but worker has no data "
                                   "to store to disk. Perhaps some other process "
                                   "is leaking memory? Process memory: %s -- "
                                   "Worker memory limit: %s",
                                   format_bytes(proc.memory_info().rss),
                                   format_bytes(self.memory_limit))
                    break
                k, v, weight = self.data.host.fast.evict()
                del k, v
                total += weight
                count += 1
                yield gen.moment
                memory = proc.memory_info().rss
                if total > need and memory > target:
                    # Issue a GC to ensure that the evicted data is actually
                    # freed from memory and taken into account by the monitor
                    # before trying to evict even more data.
                    self._throttled_gc.collect()
                    memory = proc.memory_info().rss
            if count:
                logger.debug("Moved %d pieces of data and %s bytes to disk",
                             count, format_bytes(total))

        self._memory_monitoring = False
        raise gen.Return(total)

    @gen.coroutine
    def device_memory_monitor(self):
        """ Track this worker's CUDA device memory usage and act accordingly

        If we rise above (device_memory_spill_fraction * device_memory_limit)
        of device memory use, start dumping data to host. The default value
        for device_memory_spill_fraction is 0.7, defined via configuration
        'dask-cuda.worker.device-memory.spill'.

        If we rise above (device_memory_pause_fraction * device_memory_limit)
        of device memory use, stop execution of new tasks. The default value
        for device_memory_pause_fraction is 0.8, defined via configuration
        'dask-cuda.worker.device-memory.pause'.
        """
        if self._device_memory_monitoring:
            return
        self._device_memory_monitoring = True
        total = 0
        memory = get_device_used_memory()
        frac = memory / self.device_memory_limit

        # Pause worker threads if device memory use is above
        # (self.device_memory_pause_fraction * 100)%
        self._device_paused = self._check_for_pause(
            frac, self.device_memory_pause_fraction,
            memory, self.device_memory_limit,
            self._device_paused, None,
            "Worker's CUDA device")
        self.paused = (self._paused or self._device_paused)

        # Dump device data to host if device memory use is above
        # (self.device_memory_spill_fraction * 100)%
        if (self.device_memory_spill_fraction
                and frac > self.device_memory_spill_fraction):
            target = self.device_memory_limit * self.device_memory_target_fraction
            count = 0
            while memory > target:
                if not self.data.device.fast:

[Inline review discussion on this line]

I have a question regarding this: what is the data that is stored in […]? I was trying to persist a dataframe bigger than the GPU memory, with individual chunks that comfortably fit in device memory, but it seems to get paused and throws the following warning on both workers:

distributed.worker - WARNING - CUDA device memory use is high but worker has no data to store to host. Perhaps some other process is leaking memory? Process memory: 13.76 GB -- Worker memory limit: 17.07 GB

Also, is there a way to access the […] directly?

IMO, this is one of the trickiest parts. We can't guarantee that […]. I think this isn't possible; why would you like to access that directly? I don't know if this is by design or just something that was never implemented. @mrocklin, could you clarify this?

Correct. The workers are in separate processes, so there is no way to access them from Python. You can ask Dask to run functions on them to inspect state if you like:

def f(dask_worker):
    return len(dask_worker.data.device.fast)

client.run(f)

(See the run docstring for more information.) You could also try the […].

Note that when @pentschev says "dask.array device chunks" he also means any piece of GPU-allocated data, which could be a dask.array device chunk as he's dealing with, or a cudf dataframe as you're dealing with.

Thanks a lot @mrocklin for the function to inspect state. I used it to debug what's happening. It seems that the data is being evicted from […]:

('tcp://172.17.0.2:44914', {'data.device.fast': 14, 'data.device.slow': 0})
('tcp://172.17.0.2:44914', {'data.device.fast': 15, 'data.device.slow': 0})
('tcp://172.17.0.2:44914', {'data.device.fast': 3, 'data.device.slow': 14})
('tcp://172.17.0.2:44914', {'data.device.fast': 3, 'data.device.slow': 14})
('tcp://172.17.0.2:44914', {'data.device.fast': 0, 'data.device.slow': 18})
('tcp://172.17.0.2:44914', {'data.device.fast': 0, 'data.device.slow': 18})

Error: […]

Maybe the del here is not clearing memory, I don't know: https://github.com/rapidsai/dask-cuda/pull/35/files#diff-c87f0866b277f959dc7c5d1e4b0ff015R243. Will add a small minimal reproducible example here soon.

I agree that this is a major issue, which is why I'm concerned with it. In particular, I think pausing is something that can't be enabled under these circumstances: if it is, then when Dask spills memory to host but can't really release it, the worker will get stuck. I have two proposals (not necessarily mutually exclusive) until we come up with something else: […] We can also disable pausing by default, which I'm inclined to think should be the default to prevent this sort of situation.

For 1 see dask/distributed#2453. Disabling pausing by default seems fine to me. This is just a config value change at this point, yes?

Yes, that would be something similar, if not exactly that (sorry, I can't understand all the details without diving in a bit deeper). Any thoughts on item 2 as well?

Yes. I just don't know whether disabling pause has other consequences. On the host, I guess this is to prevent the host from running out of memory and eventually getting killed, is that right?

I can see how it would solve the problem. I guess I'm hoping that medium term it's not necessary. My inclination is to wait until we have a real-world problem that needs this before adding it. I won't be surprised if that problem occurs quickly, but I'd still rather put it off and get this in.

No objections from me. That said, I have no further changes to be added; from my side, it's ready for more reviews or merging.
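
To make the inspection pattern from this thread easier to reuse, here is a hedged, self-contained sketch of the same client.run() approach; the scheduler address is a placeholder and the helper name storage_counts is hypothetical.

from dask.distributed import Client

client = Client('tcp://scheduler-address:8786')  # placeholder address

def storage_counts(dask_worker):
    # dask_worker.data is the DeviceHostFile added by this PR; `fast` holds
    # keys still on the device, `slow` holds keys spilled to the host buffer
    return {'data.device.fast': len(dask_worker.data.device.fast),
            'data.device.slow': len(dask_worker.data.device.slow)}

print(client.run(storage_counts))  # {worker_address: {...}, ...}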

                    logger.warning("CUDA device memory use is high but worker has "
                                   "no data to store to host. Perhaps some other "
                                   "process is leaking memory? Process memory: "
                                   "%s -- Worker memory limit: %s",
                                   format_bytes(get_device_used_memory()),
                                   format_bytes(self.device_memory_limit))
                    break
                k, v, weight = self.data.device.fast.evict()
                del k, v
                total += weight
                count += 1
                yield gen.moment
                memory = get_device_used_memory()
            if count:
                logger.debug("Moved %d pieces of data and %s bytes to host memory",
                             count, format_bytes(total))

        self._device_memory_monitoring = False
        raise gen.Return(total)
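
As a worked example of the thresholds this monitor acts on (not part of the diff), assume a hypothetical 16 GB device, device_memory_limit='auto', and the default fractions from dask_cuda.yaml:

device_memory_limit = 16_000_000_000  # what 'auto' would resolve to on a 16 GB GPU

target = 0.60 * device_memory_limit   # evict device data until usage drops below ~9.6 GB
spill = 0.70 * device_memory_limit    # start spilling device data to host above ~11.2 GB
pause = 0.80 * device_memory_limit    # stop scheduling new tasks above ~12.8 GB

print(f"spill above {spill / 1e9:.1f} GB, "
      f"evict down to {target / 1e9:.1f} GB, "
      f"pause above {pause / 1e9:.1f} GB")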

Diff: dask_cuda.yaml (new file)

@@ -0,0 +1,8 @@
dask-cuda:
  worker:
    # Fractions of device memory at which we take action to avoid memory blowup
    # Set any of the lower three values to False to turn off the behavior entirely
    device-memory:
      target: 0.60  # target fraction to stay below
      spill: 0.70   # fraction at which we spill to host
      pause: 0.80   # fraction at which we pause worker threads
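
As a hedged example of adjusting these defaults at runtime (not part of the diff), dask.config.set can override any of the values above; the 0.75 and 0.85 values are arbitrary illustrations.

import dask
import dask_cuda  # importing the package registers the defaults above

dask.config.set({'dask-cuda.worker.device-memory.spill': 0.75,
                 'dask-cuda.worker.device-memory.pause': 0.85})

print(dask.config.get('dask-cuda.worker.device-memory.spill'))  # 0.75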

Diff: command-line worker module

@@ -6,7 +6,7 @@
 from sys import exit

 import click
-from distributed import Nanny, Worker
+from distributed import Nanny
 from distributed.config import config
 from distributed.utils import get_ip_interface, parse_timedelta
 from distributed.worker import _ncores

@@ -23,6 +23,7 @@
     enable_proctitle_on_current,
 )

+from .cuda_worker import CUDAWorker
 from .local_cuda_cluster import cuda_visible_devices
 from .utils import get_n_gpus

@@ -98,6 +99,15 @@
     "string (like 5GB or 5000M), "
     "'auto', or zero for no memory management",
 )
+@click.option(
+    "--device-memory-limit",
+    default="auto",
+    help="Bytes of memory per CUDA device that the worker can use. "
+    "This can be an integer (bytes), "
+    "float (fraction of total device memory), "
"string (like 5GB or 5000M), " | ||
"'auto', or zero for no memory management", | ||
) | ||

[Inline review discussion on this option]

I would prefer to avoid this until someone asks for it, if possible. I'm somewhat against an API like this because it forces us to enumerate the possible options in code. If we were to do something like this, I think we would probably provide the full namespace of the class and then try to import it. However, there is enough uncertainty here that, for maintenance reasons, I'd prefer we not promise anything until someone is asking us for this.

For context, the […]

That's not true. We solely disabled pausing the worker; to control spilling from the device, we still need to monitor the device memory. And this is why I needed to subclass it. Unfortunately, this is necessary for us to reenable the old […].

Why do we need to monitor device memory externally from the use of […]? Dask workers operated this way for a long time before we finally needed to give in and have a separate periodic callback that tracked system memory. I'm inclined to try the simpler approach first and see if it breaks down or not.

Sure, they do track how much memory they take. However, tracking the device memory lets us decide when it's time to spill memory. Isn't that what […]? The block https://github.com/rapidsai/dask-cuda/pull/35/files#diff-a77f0c6f19d8d34d59aede5e31455719R282 controls the spilling, and this is why we needed to subclass […].

Yes, something like that diff. You'll also want to add the […].

Thanks for the explanation on the memory monitor. Yes, and I also want to create the object before. :) But OK, I can probably have it quickly done by tomorrow. There are a few more things that need to be ported to allow it to work (like finding out how much memory the device has in total), and also some test(s), which shouldn't be too difficult now that there's already one that works with the monitoring mechanism and I can base it on that.

It's actually valid to pass just the class. Dask will construct it. I think that this is explained in the […]. I recommend that we start with just using the full memory or a config value by default and not mess with any user inputs (which will get messy).

Ok, I'll check that. We need to identify how much memory there is available for the device, regardless. I can probably use the same numba code from before.

Right, mostly I want to say let's not add a new […].
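
For illustration only, a minimal sketch of the "provide the full namespace of the class and then try to import it" idea mentioned above; the helper name import_worker_class and the dotted path are hypothetical, not an API proposed by this PR.

import importlib

def import_worker_class(qualname):
    """Import a class from a dotted 'package.module.ClassName' string."""
    module_name, _, class_name = qualname.rpartition('.')
    module = importlib.import_module(module_name)
    return getattr(module, class_name)

# Hypothetical usage:
# WorkerClass = import_worker_class('dask_cuda.cuda_worker.CUDAWorker')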

 @click.option(
     "--reconnect/--no-reconnect",
     default=True,

@@ -146,6 +156,7 @@ def main(
     nthreads,
     name,
     memory_limit,
+    device_memory_limit,
     pid_file,
     reconnect,
     resources,

@@ -243,6 +254,7 @@ def del_pid_file():
             loop=loop,
             resources=resources,
             memory_limit=memory_limit,
+            device_memory_limit=device_memory_limit,
             reconnect=reconnect,
             local_dir=local_directory,
             death_timeout=death_timeout,

@@ -252,6 +264,7 @@ def del_pid_file():
             contact_address=None,
             env={"CUDA_VISIBLE_DEVICES": cuda_visible_devices(i)},
             name=name if nprocs == 1 or not name else name + "-" + str(i),
+            worker_class=CUDAWorker,
             **kwargs
         )
         for i in range(nprocs)

Diff: device_host_file.py (new file)

@@ -0,0 +1,54 @@
from zict import Buffer, File, Func
from zict.common import ZictBase
from distributed.protocol import deserialize_bytes, serialize_bytelist
from distributed.worker import weight

try:
    from cytoolz import partial
except ImportError:
    from toolz import partial

import os


def _is_device_object(obj):
    return hasattr(obj, '__cuda_array_interface__')


class DeviceHostFile(ZictBase):
    def __init__(self, device_memory_limit=None, memory_limit=None,
                 local_dir='dask-worker-space', compress=False):
        path = os.path.join(local_dir, 'storage')

        self.device_func = dict()
        self.host_func = dict()
        self.disk_func = Func(partial(serialize_bytelist, on_error='raise'),
                              deserialize_bytes, File(path))

        self.host = Buffer(self.host_func, self.disk_func, memory_limit,
                           weight=weight)
        self.device = Buffer(self.device_func, self.host, device_memory_limit,
                             weight=weight)

        self.fast = self.host.fast

[Inline review discussion on this line]

Thinking a bit about the names here. My first reaction was to expect […]. As a result, we might consider a renaming, something like:

    self.device = dict()
    self.host = dict()
    self.disk = Func(..., File(...))

And then we might rename the buffers to something else (I can't think of a good name right now). This makes it somewhat pleasant when inspecting this object, to see what is on the device and host, and so on. One can do this now, but only if you know how to dive within […].

I tried doing this in aa0ddd4 but wasn't able to make things clean before I had to context switch away.

I'm not sure those changes make things much clearer. FWIW, I'd normally prefer to document what such attributes mean; from a user's perspective, I'm not sure how much clearer things get from renaming […]. Also, is there a use case for users to access that data directly?

Yes, people often want to ask things like "how much of my data is currently on the device vs on the host?" They'll run Dask functions that let them query the worker state directly (this is common among advanced users). With the current naming they might do something like […].

Perhaps we should then provide wrappers to that information only? e.g., […]

Yes, that's my primary objective in the renaming. I mostly want the names to be easily interpretable for novice users who run into this object quickly without knowing much about it.

You'd like to finish that yourself? Otherwise I can do that tomorrow.

Nope. Other things have taken precedence. I thought it might be a quick fix, so thought it'd be faster to just do it. I ended up being mistaken :) Also happy with other naming schemes if you prefer. Mostly I want to avoid having to explain how to dive into these objects to users in the future. If documentation is the best way to do that, I'm happy with that as well.

I will do it tomorrow then. I'm fine with the naming scheme, as long as we can keep the external names simple ([…]).

I pushed the changes we discussed yesterday, let me know if you agree with the naming scheme.

    def __setitem__(self, key, value):
        if _is_device_object(value):
            self.device[key] = value
        else:
            self.host[key] = value

    def __getitem__(self, key):
        if key in self.device:
            return self.device[key]
        else:
            raise KeyError

    def __len__(self):
        return self.device.__len__()

    def __iter__(self):
        return self.device.__iter__()

    def __delitem__(self, i):
        return self.device.__delitem__(i)
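
A minimal usage sketch of the class above (not part of the diff): it assumes a CUDA-capable machine, uses numba to create an object that exposes __cuda_array_interface__, and assumes the module is importable as dask_cuda.device_host_file; the byte limits are arbitrary.

import numpy as np
from numba import cuda

from dask_cuda.device_host_file import DeviceHostFile  # assumed module path

data = DeviceHostFile(device_memory_limit=1_000_000_000,  # device -> host spill threshold
                      memory_limit=2_000_000_000,         # host -> disk spill threshold
                      local_dir='dask-worker-space')

data['x'] = np.arange(10)                  # no __cuda_array_interface__: goes to the host buffer
data['y'] = cuda.to_device(np.arange(10))  # device object: stays in the device buffer

print('x' in data.host, 'y' in data.device)  # True True (nothing has spilled yet)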

[Review comments on the device memory query]

This looks great. I didn't know that Numba could do this. Also, for general awareness, there is also https://github.com/gpuopenanalytics/pynvml/, though I haven't used it much myself.

Didn't know about that one. I don't have a strong opinion, but since Numba is a package more likely to be already installed and we only check memory for the time being, I think it makes more sense to just use Numba for now.
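
To make the comparison concrete, here is a hedged sketch of both approaches discussed above; it assumes a CUDA-capable machine and, for the second half, that the pynvml package is installed.

from numba import cuda
import pynvml

# Numba: memory info for the device backing the current CUDA context
free, total = cuda.current_context().get_memory_info()  # (free, total) in bytes
print('numba :', total - free, 'bytes used of', total)

# pynvml: per-device memory info straight from NVML
pynvml.nvmlInit()
handle = pynvml.nvmlDeviceGetHandleByIndex(0)            # device 0, for illustration
info = pynvml.nvmlDeviceGetMemoryInfo(handle)
print('pynvml:', info.used, 'bytes used of', info.total)
pynvml.nvmlShutdown()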