Means of zarr arrays cause a memory overload in dask workers #6709

Closed
robin-cls opened this issue Jun 20, 2022 · 17 comments
Labels
topic-dask topic-zarr Related to zarr storage library upstream issue


@robin-cls

robin-cls commented Jun 20, 2022

What is your issue?

Hello everyone !

I am submitting this issue here but it is not entirely clear if my problem comes from xarray, dask or zarr.

The goal here is to compute means from the GCM anomalies of SSH. The following simple code creates an artificial dataset containing the anomaly fields (each variable is about 90G) and computes the means of their cross-products.

import dask.array as da
import numpy as np
import xarray as xr

ds = xr.Dataset(
    dict(
        anom_u=(["time", "face", "j", "i"], da.ones((10311, 1, 987, 1920), chunks=(10, 1, 987, 1920))),
        anom_v=(["time", "face", "j", "i"], da.ones((10311, 1, 987, 1920), chunks=(10, 1, 987, 1920))),
    )
)

ds["anom_uu_mean"] = (["face", "j", "i"], np.mean(ds.anom_u.data**2, axis=0))
ds["anom_vv_mean"] = (["face", "j", "i"], np.mean(ds.anom_v.data**2, axis=0))
ds["anom_uv_mean"] = (["face", "j", "i"], np.mean(ds.anom_u.data * ds.anom_v.data, axis=0))

ds[["anom_uu_mean", "anom_vv_mean", "anom_uv_mean"]].compute()

I was expecting low memory usage: once a chunk of anom_u and the matching chunk of anom_v have been used for one step of the mean, both can be released. The following figure confirms that memory usage stays very low, so all is well.

[Image]

The matter becomes more complicated when the dataset is opened from a Zarr store. We simply dumped our previously generated artificial data to a temporary store and reloaded it:

import dask.array as da
import numpy as np
import xarray as xr

ds = xr.Dataset(
    dict(
        anom_u=(["time", "face", "j", "i"], da.ones((10311, 1, 987, 1920), chunks=(10, 1, 987, 1920))),
        anom_v=(["time", "face", "j", "i"], da.ones((10311, 1, 987, 1920), chunks=(10, 1, 987, 1920))),
    )
)

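store = "/work/scratch/test_zarr_graph"
# Note: with compute=False, to_zarr writes only the metadata and returns a Delayed;
# the chunk data is written only when that Delayed is computed.
ds.to_zarr(store, compute=False, mode="a")
ds = xr.open_zarr(store)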

ds["anom_uu_mean"] = (["face", "j", "i"], np.mean(ds.anom_u.data**2, axis=0))
ds["anom_vv_mean"] = (["face", "j", "i"], np.mean(ds.anom_v.data**2, axis=0))
ds["anom_uv_mean"] = (["face", "j", "i"], np.mean(ds.anom_u.data * ds.anom_v.data, axis=0))

ds[["anom_uu_mean", "anom_vv_mean", "anom_uv_mean"]].compute()

[Image]

I was expecting similar behavior from a dataset created from scratch and one opened from a Zarr store, but that does not seem to be the case. I tried using inline_array=True with xr.open_dataset, to no avail. I also tried computing 2 variables instead of 3 and that works properly, so the behavior seems strange to me.

Do you see any reason why I am seeing such a high memory load on my workers?

Here are the software versions I use:
xarray version: 2022.6.0rc0
dask version: 2022.04.1
zarr version: 2.11.1
numpy version: 1.21.6

@robin-cls robin-cls added the needs triage Issue that has not been reviewed by xarray team member label Jun 20, 2022
@TomNicholas TomNicholas added topic-dask topic-zarr Related to zarr storage library and removed needs triage Issue that has not been reviewed by xarray team member labels Jun 21, 2022
@TomNicholas
Member

Hi @robin-cls - can we see the dask graphs (or a representative subset of them) for these two cases?

@robin-cls
Author

robin-cls commented Jun 22, 2022

Hi @TomNicholas

I've reduced the original dataset to 11 chunks over the time dimension so that we can see the graph properly. I also replaced the .compute() call with to_zarr(compute=False) because I don't know how to visualize xarray operations without generating a Delayed object (comments are welcome on this point!)

Anyway, here are the files. The first one is the graph where the means are built from dask.array.ones arrays:

[Image: graph_no_zarr_source]

The second one is the graph where the means are built from the same arrays, but opened from a Zarr store:

[Image: graph_zarr_source]

I am quite new to debugging dask graphs, but everything seems OK in the second graph, apart from the open_dataset tasks that are linked to a parent task. I also noticed that Dask has fused the 'ones' operations in the first graph. Would it help if I generated other arrays with zeros instead?

@dcherian
Contributor

You can use dask.visualize(xarray_object)

You could try passing inline_array=True to open_zarr
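For readers who want to inspect the graph without rendering it: xarray objects implement the dask collection protocol, so `__dask_graph__()` returns the merged graph of all dask-backed variables, and `dask.visualize` can render the same object when graphviz is installed. A minimal sketch, assuming a scaled-down version of the dataset (the shapes are made up for illustration):

```python
import dask.array as da
import xarray as xr

# Hypothetical, scaled-down version of the dataset in the issue
ds = xr.Dataset(
    dict(
        anom_u=(["time", "j", "i"], da.ones((40, 8, 8), chunks=(10, 8, 8))),
        anom_v=(["time", "j", "i"], da.ones((40, 8, 8), chunks=(10, 8, 8))),
    )
)
means = xr.Dataset(
    dict(
        anom_uu_mean=(["j", "i"], (ds.anom_u.data**2).mean(axis=0)),
        anom_vv_mean=(["j", "i"], (ds.anom_v.data**2).mean(axis=0)),
        anom_uv_mean=(["j", "i"], (ds.anom_u.data * ds.anom_v.data).mean(axis=0)),
    )
)

# The merged graph of all three means, obtained without computing anything
graph = means.__dask_graph__()
print(len(graph), "tasks")
print(sorted(graph.layers))  # layer names are generated by dask

# To render it instead (requires the graphviz package and binaries):
# import dask
# dask.visualize(means, filename="means.svg")
```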

@robin-cls
Author

robin-cls commented Jun 23, 2022

Thanks for the tips. I tried inline_array=True, still with no luck. The graph seems OK though; I can attach it if you want, but I don't think zarr is the culprit.

Here is why:

In the first case, where we build the arrays from scratch, the ones arrays are simple, and Dask seems to understand that it does not have to make many copies of them. When we replace ones with random data, we observe the same behavior as when opening the dataset from a Zarr store (high memory usage on a worker):

import dask.array as da
import numpy as np
import xarray as xr

ds = xr.Dataset(
    dict(
        anom_u=(["time", "face", "j", "i"], da.random.random((10311, 1, 987, 1920), chunks=(10, 1, 987, 1920))),
        anom_v=(["time", "face", "j", "i"], da.random.random((10311, 1, 987, 1920), chunks=(10, 1, 987, 1920))),
    )
)

ds["anom_uu_mean"] = (["face", "j", "i"], np.mean(ds.anom_u.data**2, axis=0))
ds["anom_vv_mean"] = (["face", "j", "i"], np.mean(ds.anom_v.data**2, axis=0))
ds["anom_uv_mean"] = (["face", "j", "i"], np.mean(ds.anom_u.data * ds.anom_v.data, axis=0))

ds[["anom_uu_mean", "anom_vv_mean", "anom_uv_mean"]].compute()

I think the question now is why Dask has to load so much data for this operation:

[Image: graph]

Looking at the computation graph (I've posted the non-optimized version), my understanding is that we could do the following:

  • Load the first chunk of anom_u
  • Load the first chunk of anom_v
  • Compute anom_u * anom_v, anom_u ** 2 and anom_v ** 2
  • Run the mean_chunk tasks
  • Release the results of all the previous tasks
  • Repeat for the next chunks and combine the mean_chunk results
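The plan above amounts to a streaming reduction. Here is a minimal NumPy sketch of the intended memory profile; it is a sequential loop, not how dask actually schedules the graph, and the function name `streaming_cross_means` is hypothetical. Only running sums are carried from one chunk to the next, similar in spirit to dask's mean_chunk tasks, which produce partial sums and counts that are combined at the end.

```python
import numpy as np


def streaming_cross_means(u, v, chunk_size):
    """Time means of u**2, v**2 and u*v, accumulated one time-chunk at a time."""
    n_time = u.shape[0]
    sum_uu = np.zeros(u.shape[1:])
    sum_vv = np.zeros(u.shape[1:])
    sum_uv = np.zeros(u.shape[1:])
    for start in range(0, n_time, chunk_size):
        u_chunk = u[start:start + chunk_size]  # "load" one chunk of anom_u
        v_chunk = v[start:start + chunk_size]  # and the matching chunk of anom_v
        sum_uu += (u_chunk**2).sum(axis=0)
        sum_vv += (v_chunk**2).sum(axis=0)
        sum_uv += (u_chunk * v_chunk).sum(axis=0)
        # In an out-of-core setting, the previous chunks are no longer referenced here
    return sum_uu / n_time, sum_vv / n_time, sum_uv / n_time


rng = np.random.default_rng(0)
u = rng.random((33, 4, 5))  # 33 time steps -> chunks of 10, 10, 10 and 3
v = rng.random((33, 4, 5))
uu, vv, uv = streaming_cross_means(u, v, chunk_size=10)
```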

For reference, one chunk is about 1.4G, so I expected peaks of 5 * 1.4 = 7G in memory (plus what's needed to store the mean_chunk results), but instead I got 15G+ on my worker, most of it taken by the random-sample tasks.
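The same back-of-the-envelope estimate can be scripted. A small sketch (the helper `chunk_nbytes` is hypothetical): for the chunk shape in the MVCE, (10, 1, 987, 1920) in float64, one chunk comes out at roughly 152 MB rather than 1.4G, so plug in the chunk shape of the data you actually run on. For a dask array `arr`, the same number is `math.prod(arr.chunksize) * arr.dtype.itemsize`.

```python
import math

import numpy as np


def chunk_nbytes(chunk_shape, dtype=np.float64):
    """Bytes held by one chunk: number of elements times the item size."""
    return math.prod(chunk_shape) * np.dtype(dtype).itemsize


# Chunk shape used in the MVCE; da.ones and da.random.random default to float64
one_chunk = chunk_nbytes((10, 1, 987, 1920))
# anom_u, anom_v, u*v, u**2 and v**2: five chunk-sized arrays alive at once
expected_peak = 5 * one_chunk
print(f"{one_chunk / 1e9:.3f} GB per chunk, about {expected_peak / 1e9:.2f} GB at peak")
```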

[Image]

Is my understanding of distributed mean wrong? Why are the random-samples not flushed?

@dcherian
Contributor

dcherian commented Jun 23, 2022

This looks like the classic distributed scheduling issue that might get addressed soon: dask/distributed#6560

@gjoseph92 the MVCE in #6709 (comment) might be a useful benchmark / regression test (ds.u.mean(), ds.v.mean(), (ds.u * ds.v).mean() all computed together). This kind of dependency, where variables and combinations of multiple variables are computed together, turns up a lot in climate model workflows.

@gjoseph92

Thanks @dcherian, yeah this is definitely root task overproduction. I think your case is somewhat similar to @TomNicholas's dask/distributed#6571 (that one might even be a little simpler actually).

There's some prototyping going on to address this, but FYI I'd say "soon" is probably on a timescale of a couple of months right now.

dask/distributed#6598 or dask/distributed#6614 will probably make this work. I'm hopefully going to benchmark these against some real workloads in the next couple days, so I'll probably add yours in. Thanks for the MVCE!

> Is my understanding of distributed mean wrong? Why are the random-samples not flushed?

See dask/distributed#6360 (comment) and the linked issues for why this happens.

@dcherian
Contributor

Thanks @gjoseph92. I think there is a small but important increase in complexity here: we compute ds.u.mean(), ds.v.mean() and (ds.u*ds.v).mean() all together, so each chunk of ds.u and ds.v is used for two different outputs.

IIUC the example in dask/distributed#6571 is basically (ds.u * ds.v).mean() purely.
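To make this concrete, here is a small sketch with plain dask.array (hypothetical shapes, not the xarray code from the issue) that merges the graphs of the three reductions and counts the tasks reading the first chunk of u. Each input chunk feeds two different reductions, so it cannot be released until both consumers have run.

```python
import dask.array as da
from dask.highlevelgraph import HighLevelGraph

# Small hypothetical shapes with the same structure as the MVCE
u = da.random.random((8, 4), chunks=(2, 4))
v = da.random.random((8, 4), chunks=(2, 4))
outputs = [(u**2).mean(axis=0), (v**2).mean(axis=0), (u * v).mean(axis=0)]

# Merge the three graphs, as happens when the outputs are computed together
graph = HighLevelGraph.merge(*(out.__dask_graph__() for out in outputs))
dependencies = graph.get_all_dependencies()  # key -> set of keys it depends on

# Tasks that directly consume the first chunk of u (unoptimized graph)
first_u_chunk = u.__dask_keys__()[0][0]
consumers = [key for key, deps in dependencies.items() if first_u_chunk in deps]
print(len(consumers))  # one task from u**2 and one from u * v
```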

@robin-cls
Author

Thanks @gjoseph92 and @dcherian. I'll try the approaches in the links you provided to see if I can improve my current solution (I compute the fields separately, which means more IO and more operations).
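The workaround described here (one computation per field) might look like this with plain dask.array; the shapes are hypothetical stand-ins for anom_u and anom_v. It trades repeated IO or recomputation of the inputs for a lower memory peak.

```python
import dask
import dask.array as da
import numpy as np

# Hypothetical small inputs standing in for anom_u and anom_v
u = da.random.random((40, 6, 6), chunks=(10, 6, 6))
v = da.random.random((40, 6, 6), chunks=(10, 6, 6))

# Workaround: one compute() per field, so each graph feeds a single output.
# The price is that the input chunks are produced (or read) once per call.
uu = (u**2).mean(axis=0).compute()
vv = (v**2).mean(axis=0).compute()
uv = (u * v).mean(axis=0).compute()

# Joint computation of the same three fields, for comparison
uu_joint, vv_joint, uv_joint = dask.compute(
    (u**2).mean(axis=0), (v**2).mean(axis=0), (u * v).mean(axis=0)
)

# The random seeds are fixed when u and v are created, so both paths agree
same = all(
    np.allclose(a, b) for a, b in [(uu, uu_joint), (vv, vv_joint), (uv, uv_joint)]
)
print(same)
```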

@gjoseph92

FYI @robin-cls, I would be a bit surprised if there is anything you can do on your end to fix this with off-the-shelf dask. What @dcherian mentioned in dask/distributed#6360 (comment) is probably the only thing that might work. Otherwise, you'll need to run one of my experimental branches.

@gjoseph92

I took a little bit more of a look at this and I don't think root task overproduction is the (only) problem here.

I also feel that, intuitively, this operation shouldn't require holding so many root tasks at once. But the graph dask is making, or the way it's ordering it, doesn't seem to work that way. We can see the ordering is pretty bad:

[Image: multi-mean]

When we actually run it (on dask/distributed#6614, with overproduction fixed), you can see that dask has to keep tons of input chunks in memory because they're needed by a future task that can't run yet (not all of its inputs have been computed):

[Image: Screen Shot 2022-06-23 at 5 10 19 PM]

I feel like it's possible that the order in which dask executes the input tasks is bad. But I suspect it's more likely that I haven't thought about the problem enough and there's an obvious reason why the graph is structured like this.
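To look at the static ordering yourself, one option is to call `dask.order.order` on the merged graph. A sketch with small hypothetical shapes: it prints the priority assigned to each input chunk of u and v (lower means dask prefers to run it earlier). The actual numbers depend on the dask version, since dask.order has changed over time.

```python
import dask.array as da
from dask.highlevelgraph import HighLevelGraph
from dask.order import order

# Small hypothetical inputs with the same structure as the MVCE
u = da.random.random((8, 4), chunks=(2, 4))
v = da.random.random((8, 4), chunks=(2, 4))
outputs = [(u**2).mean(axis=0), (v**2).mean(axis=0), (u * v).mean(axis=0)]

# Materialize the merged graph and ask dask.order for its static priorities
dsk = dict(HighLevelGraph.merge(*(out.__dask_graph__() for out in outputs)))
priority = order(dsk)  # key -> priority; lower numbers are preferred earlier

# When does each input chunk get produced?
input_keys = [key for arr in (u, v) for row in arr.__dask_keys__() for key in row]
for key in sorted(input_keys, key=priority.__getitem__):
    label = "u" if key[0] == u.name else "v"
    print(label, key[1:], priority[key])
```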

norlandrhagen added a commit to carbonplan/cmip6-downscaling that referenced this issue Jun 24, 2022
norlandrhagen added a commit to carbonplan/cmip6-downscaling that referenced this issue Jun 30, 2022
* as committed here ran a test case for western US using three features to
predict precip - a couple remaining TODOs but functional at this point

* allow access to wind and standardize names/coordinates

* tweaks for multivariate run

* sample config for multivariate run

* update config files for multivariate and boolean in detrend

* variable switching

* param configs with wind

* split get_gcm by variable due to this issue with dask: pydata/xarray#6709

* switched order of postprocess and added method = nearest to .sel

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Co-authored-by: norland r hagen <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
@fjetter
Contributor

fjetter commented Jun 29, 2023

xref dask/dask#10384: I was able to confirm @gjoseph92's suspicion that this is a dask.order issue. It seems to be triggered because xarray is running multiple reductions on the same source data (dask should be robust to this; I'm just pointing it out).

This example should work trivially when either is true:

  1. Only one of the arrays is computed at a time
  2. The xarray dataset is transformed to a dask.DataFrame using mean.to_dask_dataframe() (The DataFrame graph looks slightly different and is handled well by dask.order)

@fjetter
Contributor

fjetter commented Sep 28, 2023

Just a heads up: I'm working on a fix for this in dask/dask, see dask/dask#10535

Preliminary results look very promising

[Image]

This graph shows the memory usage for a couple of runs with increasing size in the time dimension, using the example posted in pangeo-data/distributed-array-examples#2 (comment)

[Image]
This was far from the spilling threshold (yellow line), so the constant memory was indeed due to better scheduling, not spilling or anything like that.

I'm also looking at other workloads. If you are aware of other stuff that should be constant or near-constant in memory usage but isn't, please let me know!

@dcherian
Contributor

> I'm also looking at other workloads. If you are aware of other stuff that should be constant or near-constant in memory usage but isn't, please let me know!

There is this one: dask/distributed#7274 (comment) / dask/distributed#8005 that I didn't manage to make a reproducible example for.

@fjetter
Contributor

fjetter commented Sep 29, 2023

> There is this one:

I remember, thank you. I'll have a look at it next.

@dcherian
Contributor

dcherian commented Oct 4, 2023

(responding to your email here)

Does something like this look bad to you?

import dask.array

big_array = dask.array.random.random((100_000, 3_000, 3_000), chunks=(20, 1000, 1000))  # ~7.2 TB of float64
small_array_1 = dask.array.random.random((1, 3_000, 3_000), chunks=1000)  # more than 5 chunks, so not root-ish
small_array_2 = dask.array.random.random((1, 3_000, 3_000), chunks=1000)  # more than 5 chunks, so not root-ish
result = (big_array * small_array_1 * small_array_2).mean(axis=(-1, -2))

@fjetter
Contributor

fjetter commented Oct 4, 2023

> Does something like this look bad to you?

It doesn't look bad at all. Nothing here is detected as root-ish, but it runs very smoothly on my machine (both on main and with my new ordering PR).

@dcherian
Contributor

dcherian commented Oct 9, 2023

Closing since dask/dask#10535 will be merged soon.
