Excessive memory use in fold-style reductions #7552

Open · gjoseph92 opened this issue Feb 16, 2023 · 2 comments

gjoseph92 (Collaborator) commented:

This graph is not parallel. It's an incremental, serial reduction. Each reducer requires the previous reducer to finish before it can run. I've set up the tasks so that reducers are significantly slower than data producers.

Therefore, there's no need to load all the inputs into memory up front. It's going to be a long time until the final input task can be used. If we load it right away, it'll just take up memory.
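
For illustration, the dependency structure looks roughly like this (a sketch with illustrative keys, not the actual keys dask generates): each fit needs the previous fit plus exactly one load, so the fits form a strictly serial chain and each load's output is consumed at exactly one point in it.

# A sketch of the graph structure (illustrative keys only):
# fit-i depends on fit-(i-1) and load-i, so only one fit can run at a time,
# and each load's result is needed at just one step of the chain.
graph = {
    "load-0": (load,),
    "load-1": (load,),
    "load-2": (load,),
    "fit-0": (fit, 0, "load-0"),
    "fit-1": (fit, "fit-0", "load-1"),
    "fit-2": (fit, "fit-1", "load-2"),
    # ...
}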

[screenshot: Screen Shot 2023-02-15 at 7 49 44 PM]

As you can see, even though the load tasks were queued, far more data was loaded into memory than we can process at once.

[screenshot: Screen Shot 2023-02-15 at 7 49 50 PM]

With larger data sizes, or if there was some other computation going on at the same time, this probably could have killed the cluster.

This was motivated by playing around with dask-ml and incremental training. AFAIU, the point of incremental training is to be able to train on a larger-than-memory dataset by training on it chunk by chunk. But it seems this scheduling behavior might defeat the purpose, since all the data will end up loaded into distributed memory anyway (as long as training is slower than data loading, which is quite possible with a big ML model). Hopefully spilling will save you in the real world, but it still doesn't seem like great behavior.

No ideas yet for how to address this; it's just interesting to think about in the context of other scheduling questions like #7531.

Minimal reproducer:

import time

import dask
import distributed
from dask.utils import parse_bytes


@dask.delayed(pure=False)  # pure=False so each load() is a distinct task
def load():
    # Cheap data producer: ~50 MB of payload, produced almost instantly.
    return "x" * parse_bytes("50MB")


@dask.delayed()
def fit(prev, data):
    # Slow reducer: depends on the previous fit plus one loaded chunk.
    time.sleep(1)
    return prev + 1


roots = [load() for _ in range(50)]
prev = fit(0, roots[0])
for r in roots[1:]:
    prev = fit(prev, r)


if __name__ == "__main__":
    with distributed.Client(
        n_workers=4, threads_per_worker=1, memory_limit="1 GiB"
    ) as client:
        prev.compute()
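
For scale: the 50 load outputs total roughly 2.5 GB, while the cluster above has 4 GiB of memory across four single-threaded workers, and each fit takes a full second while a load is nearly instantaneous. So within the first few seconds the workers can produce essentially all of the load data, even though the serial fit chain won't consume the last chunk for almost a minute.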

Dask-ml example:

# model.py
import time
from torch import nn


class MyModel(nn.Module):
    def __init__(self, num_units=10, nonlin=nn.ReLU()):
        super().__init__()

        self.dense0 = nn.Linear(20, num_units)
        self.nonlin = nonlin
        self.dropout = nn.Dropout(0.5)
        self.dense1 = nn.Linear(num_units, num_units)
        self.output = nn.Linear(num_units, 2)

    def forward(self, X, **kwargs):
        time.sleep(5)  # make training much slower than data loading
        X = self.nonlin(self.dense0(X))
        X = self.dropout(X)
        X = self.nonlin(self.dense1(X))
        X = self.output(X)
        return X

# main script (separate file; imports MyModel from model.py above)
import distributed
import numpy as np
from dask_ml.datasets import make_classification
from dask_ml.wrappers import Incremental
from torch import nn

from skorch import NeuralNetClassifier

from model import MyModel


X, y = make_classification(
    100_000, 20, n_informative=10, random_state=0, chunks=(10000, 20)
)
X = X.astype(np.float32)
y = y.astype(np.int64)


niceties = {
    "callbacks": False,
    "warm_start": False,
    "train_split": None,
    "max_epochs": 1,
}

net = NeuralNetClassifier(
    MyModel,
    criterion=nn.CrossEntropyLoss(),
    lr=0.1,
    **niceties,
)

model = Incremental(net, scoring="accuracy")


if __name__ == "__main__":
    with distributed.Client() as client:
        model.fit(X, y)
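
For context, Incremental.fit is conceptually a fold over the chunks of the dask arrays, just like the minimal reproducer above. Roughly (a sketch of the idea only, not dask-ml's actual implementation; partial_fit_chunk is an illustrative helper):

# Conceptual sketch: the Incremental wrapper effectively folds partial_fit
# over the chunks, so each step needs the previous model state plus exactly
# one chunk of (X, y).
import dask


@dask.delayed
def partial_fit_chunk(model, X_chunk, y_chunk):
    return model.partial_fit(X_chunk, y_chunk)


fitted = dask.delayed(net)
for X_chunk, y_chunk in zip(X.to_delayed().flatten(), y.to_delayed().flatten()):
    fitted = partial_fit_chunk(fitted, X_chunk, y_chunk)
# fitted.compute() walks the same kind of serial chain of fits as above.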

TomAugspurger (Member) commented:

xref dask/dask-ml#765, which describes this too.

gjoseph92 (Collaborator, Author) commented:

Yup, same situation—thanks for the link @TomAugspurger.

Just to also note that, as expected, the ordering for the graph is good:

[image: mydask (graph visualization)]
