
[Windows] Memory error with model.predict() on TabPFNRegressor #100

Open · realshaktigupta opened this issue Jan 9, 2025 · 13 comments
Labels: documentation

Comments

@realshaktigupta

I ran the following code:

# Train the model
model = TabPFNRegressor()
model.fit(X_train, y_train)

# Make predictions
y_pred = model.predict(X_test)

This leads to the following error:

AttributeError: module 'os' has no attribute 'sysconf'

Stack Trace:

AttributeError                            Traceback (most recent call last)
Cell In[110], line 9
      6 model.fit(X_train, y_train)
      8 # Make predictions
----> 9 y_pred = model.predict(X_test)

File ~\miniconda3\envs\TabPFN\lib\site-packages\tabpfn\regressor.py:624, in TabPFNRegressor.predict(self, X, output_type, quantiles)
    621 outputs: list[torch.Tensor] = []
    622 borders: list[np.ndarray] = []
--> 624 for output, config in self.executor_.iter_outputs(
    625     X,
    626     device=self.device_,
    627     autocast=self.use_autocast_,
    628 ):
    629     assert isinstance(config, RegressorEnsembleConfig)
    631     if self.softmax_temperature != 1:

File ~\miniconda3\envs\TabPFN\lib\site-packages\tabpfn\inference.py:311, in InferenceEngineCachePreprocessing.iter_outputs(self, X, device, autocast)
    308     X_full = X_full.type(self.force_inference_dtype)
    309     y_train = y_train.type(self.force_inference_dtype)  # type: ignore # noqa: PLW2901
--> 311 MemoryUsageEstimator.reset_peak_memory_if_required(
    312     save_peak_mem=self.save_peak_mem,
    313     model=self.model,
    314     X=X_full,
    315     cache_kv=False,
    316     device=device,
    317     dtype_byte_size=self.dtype_byte_size,
    318     safety_factor=1.2,  # TODO(Arjun): make customizable
    319 )
    321 style = None
    323 with (
    324     torch.autocast(device.type, enabled=autocast),
    325     torch.inference_mode(),
    326 ):

File ~\miniconda3\envs\TabPFN\lib\site-packages\tabpfn\model\memory.py:372, in MemoryUsageEstimator.reset_peak_memory_if_required(cls, save_peak_mem, model, X, cache_kv, device, dtype_byte_size, safety_factor, n_train_samples)
    367 save_peak_mem_is_num = isinstance(
    368     save_peak_mem,
    369     (float, int),
    370 ) and not isinstance(save_peak_mem, bool)
    371 if save_peak_mem == "auto" or save_peak_mem_is_num:
--> 372     memory_available_after_batch = cls.estimate_memory_remainder_after_batch(
    373         X,
    374         model,
    375         cache_kv=cache_kv,
    376         device=device,
    377         dtype_byte_size=dtype_byte_size,
    378         safety_factor=safety_factor,
    379         n_train_samples=n_train_samples,
    380         max_free_mem=save_peak_mem
    381         if isinstance(save_peak_mem, (float, int))
    382         else None,
    383     )
    384     save_peak_mem = memory_available_after_batch < 0
    386 if save_peak_mem:

File ~\miniconda3\envs\TabPFN\lib\site-packages\tabpfn\model\memory.py:318, in MemoryUsageEstimator.estimate_memory_remainder_after_batch(cls, X, model, cache_kv, device, dtype_byte_size, safety_factor, n_train_samples, max_free_mem)
    302 """Whether to save peak memory or not.
    303 
    304 Args:
   (...)
    315     The amount of free memory available after a batch is computed.
    316 """
    317 if max_free_mem is None:
--> 318     max_free_mem = cls.get_max_free_memory(
    319         device,
    320         unit="gb",
    321         default_gb_cpu_if_failed_to_calculate=DEFAULT_CPU_MEMORY_GB_IF_NOT_CUDA,
    322     )
    324 mem_per_batch = cls.estimate_memory_of_one_batch(
    325     X,
    326     model,
   (...)
    330     n_train_samples=n_train_samples,
    331 )
    333 return max_free_mem - (mem_per_batch * safety_factor)

File ~\miniconda3\envs\TabPFN\lib\site-packages\tabpfn\model\memory.py:252, in MemoryUsageEstimator.get_max_free_memory(cls, device, unit, default_gb_cpu_if_failed_to_calculate)
    249 if device.type.startswith("cpu"):
    250     try:
    251         free_memory = (
--> 252             os.sysconf("SC_PAGE_SIZE") * os.sysconf("SC_PHYS_PAGES") / 1e9
    253         )
    254     except ValueError:
    255         warnings.warn(
    256             "Could not get system memory, defaulting to"
    257             f" {default_gb_cpu_if_failed_to_calculate} GB",
    258             RuntimeWarning,
    259             stacklevel=2,
    260         )

AttributeError: module 'os' has no attribute 'sysconf'
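
For context: os.sysconf is only available on POSIX systems, which is why the call above raises AttributeError on Windows. A minimal sketch of a portable way to query total system memory, using psutil as an assumed helper dependency (not necessarily what the upstream fix uses):

import os

import psutil  # assumed third-party dependency, purely for illustration

def total_memory_gb() -> float:
    """Return total physical memory in GB on both POSIX and Windows."""
    if hasattr(os, "sysconf"):  # POSIX: same calculation as in the traceback above
        return os.sysconf("SC_PAGE_SIZE") * os.sysconf("SC_PHYS_PAGES") / 1e9
    return psutil.virtual_memory().total / 1e9  # Windows and other platforms

print(f"{total_memory_gb():.1f} GB of physical memory")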
@mmschlk

mmschlk commented Jan 9, 2025

I see the same with Python version 3.10.11 on Windows. Increasing or decreasing the train or test size seems to have no effect.

import tabpfn

from importlib.metadata import version
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error
from sklearn.datasets import fetch_california_housing

print("tabpfn version: ", version('tabpfn'))

x_data, y_data = fetch_california_housing(return_X_y=True)

x_train, x_test, y_train, y_test = train_test_split(x_data, y_data, train_size=10000, random_state=42)

print("Train data shape: ", x_train.shape, y_train.shape)
print("Test data shape: ", x_test.shape, y_test.shape)

# model
model = tabpfn.TabPFNRegressor()
model.fit(x_train, y_train)

# prediction
predictions = model.predict(x_test[0:10])
mse = mean_squared_error(y_test[0:10], predictions[0:10])

>> tabpfn version:  2.0.0
>> Train data shape:  (10000, 8) (10000,)
>> Test data shape:  (10640, 8) (10640,)
>> ...
>> AttributeError: module 'os' has no attribute 'sysconf'

@eddiebergman
Collaborator

I have a PR in #103

Will raise a separate issue to properly deal with this.

@eddiebergman
Collaborator

eddiebergman commented Jan 9, 2025

Fixed in the main branch now; not sure when a hotfix patch comes out, but I would hope today. @noahho, this should be closed once you do a release.

@LennartPurucker changed the title from "Running model.predict() on TabPFNRegressor gives error" to "[Windows] Running model.predict() on TabPFNRegressor gives error" Jan 9, 2025
@LennartPurucker added the bug label Jan 9, 2025
@LennartPurucker
Collaborator

If you want to try out the fix now, you can install it from GitHub with the following:

git clone https://github.com/PriorLabs/TabPFN.git
pip install -e TabPFN

If you can try it out, let me know if it works for you. Thank you!

@mmschlk

mmschlk commented Jan 9, 2025

Thanks @LennartPurucker for the quick merge. I am using the fix and it works! FYI, my initial script takes quite some time on my laptop, though selecting a smaller train size helps. :)

@realshaktigupta
Author

Thanks @LennartPurucker for the fix. The sysconf-related error is gone; however, I am still getting memory allocation errors.

It tries to allocate a very large amount of memory while computing the multi-head attention output.

RuntimeError: [enforce fail at alloc_cpu.cpp:114] data. DefaultCPUAllocator: not enough memory: you tried to allocate 54629649408 bytes.

Stack Trace:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[16], line 10
      5 # Train the model
      6 model = TabPFNRegressor(categorical_features_indices=[0],fit_mode = "fit_preprocessors",memory_saving_mode=4,n_jobs=-1,ignore_pretraining_limits=True)
      7 model.fit(X_train, y_train)
      9 # Make predictions
---> 10 y_pred = model.predict(X_test).sample(100)

File ~\TabPFN\src\tabpfn\regressor.py:624, in TabPFNRegressor.predict(self, X, output_type, quantiles)
    621 outputs: list[torch.Tensor] = []
    622 borders: list[np.ndarray] = []
--> 624 for output, config in self.executor_.iter_outputs(
    625     X,
    626     device=self.device_,
    627     autocast=self.use_autocast_,
    628 ):
    629     assert isinstance(config, RegressorEnsembleConfig)
    631     if self.softmax_temperature != 1:

File ~\TabPFN\src\tabpfn\inference.py:327, in InferenceEngineCachePreprocessing.iter_outputs(self, X, device, autocast)
    321     style = None
    323     with (
    324         torch.autocast(device.type, enabled=autocast),
    325         torch.inference_mode(),
    326     ):
--> 327         output = self.model(
    328             *(style, X_full, y_train),
    329             only_return_standard_out=True,
    330             categorical_inds=cat_ix,
    331             single_eval_pos=len(y_train),
    332         )
    333     yield output.squeeze(1), config
    335 self.model = self.model.cpu()

File ~\miniconda3\envs\TabPFN\lib\site-packages\torch\nn\modules\module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
   1734     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1735 else:
-> 1736     return self._call_impl(*args, **kwargs)

File ~\miniconda3\envs\TabPFN\lib\site-packages\torch\nn\modules\module.py:1747, in Module._call_impl(self, *args, **kwargs)
   1742 # If we don't have any hooks, we want to skip the rest of the logic in
   1743 # this function, and just call forward.
   1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1745         or _global_backward_pre_hooks or _global_backward_hooks
   1746         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747     return forward_call(*args, **kwargs)
   1749 result = None
   1750 called_always_called_hooks = set()

File ~\TabPFN\src\tabpfn\model\transformer.py:413, in PerFeatureTransformer.forward(self, *args, **kwargs)
    411 if len(args) == 3:
    412     style, x, y = args
--> 413     return self._forward(x, y, style=style, **kwargs)
    415 raise ValueError("Unrecognized input. Please follow the doc string.")

File ~\TabPFN\src\tabpfn\model\transformer.py:625, in PerFeatureTransformer._forward(***failed resolving arguments***)
    617     raise ValueError(
    618         f"There should be no NaNs in the encoded x and y."
    619         "Check that you do not feed NaNs or use a NaN-handling enocder."
    620         "Your embedded x and y returned the following:"
    621         f"{torch.isnan(embedded_x).any()=} | {torch.isnan(embedded_y).any()=}",
    622     )
    623 del embedded_y, embedded_x
--> 625 encoder_out = self.transformer_encoder(
    626     (
    627         embedded_input
    628         if not self.transformer_decoder
    629         else embedded_input[:, :single_eval_pos_]
    630     ),
    631     single_eval_pos=single_eval_pos,
    632     half_layers=half_layers,
    633     cache_trainset_representation=self.cache_trainset_representation,
    634 )  # b s f+1 e -> b s f+1 e
    636 # If we are using a decoder
    637 if self.transformer_decoder:

File ~\miniconda3\envs\TabPFN\lib\site-packages\torch\nn\modules\module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
   1734     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1735 else:
-> 1736     return self._call_impl(*args, **kwargs)

File ~\miniconda3\envs\TabPFN\lib\site-packages\torch\nn\modules\module.py:1747, in Module._call_impl(self, *args, **kwargs)
   1742 # If we don't have any hooks, we want to skip the rest of the logic in
   1743 # this function, and just call forward.
   1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1745         or _global_backward_pre_hooks or _global_backward_hooks
   1746         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747     return forward_call(*args, **kwargs)
   1749 result = None
   1750 called_always_called_hooks = set()

File ~\TabPFN\src\tabpfn\model\transformer.py:74, in LayerStack.forward(self, x, half_layers, **kwargs)
     72         x = checkpoint(partial(layer, **kwargs), x, use_reentrant=False)  # type: ignore
     73     else:
---> 74         x = layer(x, **kwargs)
     76 return x

File ~\miniconda3\envs\TabPFN\lib\site-packages\torch\nn\modules\module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
   1734     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1735 else:
-> 1736     return self._call_impl(*args, **kwargs)

File ~\miniconda3\envs\TabPFN\lib\site-packages\torch\nn\modules\module.py:1747, in Module._call_impl(self, *args, **kwargs)
   1742 # If we don't have any hooks, we want to skip the rest of the logic in
   1743 # this function, and just call forward.
   1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1745         or _global_backward_pre_hooks or _global_backward_hooks
   1746         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747     return forward_call(*args, **kwargs)
   1749 result = None
   1750 called_always_called_hooks = set()

File ~\TabPFN\src\tabpfn\model\layer.py:449, in PerFeatureEncoderLayer.forward(self, state, single_eval_pos, cache_trainset_representation, att_src)
    439     raise AssertionError(
    440         "Pre-norm implementation is wrong, as the residual should never"
    441         " be layer normed here.",
    442     )
    443     state = layer_norm(
    444         state,
    445         allow_inplace=True,
    446         save_peak_mem_factor=save_peak_mem_factor,
    447     )
--> 449 state = sublayer(state)
    450 if not self.pre_norm:
    451     state = layer_norm(
    452         state,
    453         allow_inplace=True,
    454         save_peak_mem_factor=save_peak_mem_factor,
    455     )

File ~\TabPFN\src\tabpfn\model\layer.py:363, in PerFeatureEncoderLayer.forward.<locals>.attn_between_items(x)
    360     new_x_test = None
    362 if single_eval_pos:
--> 363     new_x_train = self.self_attn_between_items(
    364         x[:, :single_eval_pos].transpose(1, 2),
    365         x[:, :single_eval_pos].transpose(1, 2),
    366         save_peak_mem_factor=save_peak_mem_factor,
    367         cache_kv=cache_trainset_representation,
    368         only_cache_first_head_kv=True,
    369         add_input=True,
    370         allow_inplace=True,
    371         use_cached_kv=False,
    372     ).transpose(1, 2)
    373 else:
    374     new_x_train = None

File ~\miniconda3\envs\TabPFN\lib\site-packages\torch\nn\modules\module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
   1734     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1735 else:
-> 1736     return self._call_impl(*args, **kwargs)

File ~\miniconda3\envs\TabPFN\lib\site-packages\torch\nn\modules\module.py:1747, in Module._call_impl(self, *args, **kwargs)
   1742 # If we don't have any hooks, we want to skip the rest of the logic in
   1743 # this function, and just call forward.
   1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1745         or _global_backward_pre_hooks or _global_backward_hooks
   1746         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747     return forward_call(*args, **kwargs)
   1749 result = None
   1750 called_always_called_hooks = set()

File ~\TabPFN\src\tabpfn\model\multi_head_attention.py:355, in MultiHeadAttention.forward(self, x, x_kv, cache_kv, add_input, allow_inplace, save_peak_mem_factor, reuse_first_head_kv, only_cache_first_head_kv, use_cached_kv, use_second_set_of_queries)
    338         self._k_cache = torch.empty(
    339             batch_size,
    340             seqlen_kv,
   (...)
    344             dtype=x.dtype,
    345         )
    346         self._v_cache = torch.empty(
    347             batch_size,
    348             seqlen_kv,
   (...)
    352             dtype=x.dtype,
    353         )
--> 355 output: torch.Tensor = self._compute(
    356     x,
    357     x_kv,
    358     self._k_cache,
    359     self._v_cache,
    360     self._kv_cache,
    361     cache_kv=cache_kv,
    362     use_cached_kv=use_cached_kv,
    363     add_input=add_input,
    364     allow_inplace=allow_inplace,
    365     save_peak_mem_factor=save_peak_mem_factor,
    366     reuse_first_head_kv=reuse_first_head_kv,
    367     use_second_set_of_queries=use_second_set_of_queries,
    368 )
    369 return output.reshape(x_shape_after_transpose[:-1] + output.shape[-1:])

File ~\TabPFN\src\tabpfn\model\memory.py:94, in support_save_peak_mem_factor.<locals>.method_(self, x, add_input, allow_inplace, save_peak_mem_factor, *args, **kwargs)
     92 for x_, *args_ in split_args:
     93     if add_input:
---> 94         x_[:] += method(self, x_, *args_, **kwargs)
     95     else:
     96         x_[:] = method(self, x_, *args_, **kwargs)

File ~\TabPFN\src\tabpfn\model\multi_head_attention.py:504, in MultiHeadAttention._compute(self, x, x_kv, k_cache, v_cache, kv_cache, cache_kv, use_cached_kv, reuse_first_head_kv, use_second_set_of_queries)
    490 """Attention computation.
    491 Called by 'forward', potentially on shards, once shapes have been normalized.
    492 """
    493 q, k, v, kv, qkv = self.compute_qkv(
    494     x,
    495     x_kv,
   (...)
    502     use_second_set_of_queries=use_second_set_of_queries,
    503 )
--> 504 attention_head_outputs = MultiHeadAttention.compute_attention_heads(
    505     q,
    506     k,
    507     v,
    508     kv,
    509     qkv,
    510     self.dropout_p,
    511     self.softmax_scale,
    512 )
    513 return torch.einsum(
    514     "... h d, h d s -> ... s",
    515     attention_head_outputs,
    516     self._w_out,
    517 )

File ~\TabPFN\src\tabpfn\model\multi_head_attention.py:721, in MultiHeadAttention.compute_attention_heads(q, k, v, kv, qkv, dropout_p, softmax_scale)
    719 k = MultiHeadAttention.broadcast_kv_across_heads(k, share_kv_across_n_heads)
    720 v = MultiHeadAttention.broadcast_kv_across_heads(v, share_kv_across_n_heads)
--> 721 logits = torch.einsum("b q h d, b k h d -> b q k h", q, k)
    722 logits *= (
    723     torch.sqrt(torch.tensor(1.0 / d_k)).to(k.device)
    724     if softmax_scale is None
    725     else softmax_scale
    726 )
    727 ps = torch.softmax(logits, dim=2)

File ~\miniconda3\envs\TabPFN\lib\site-packages\torch\functional.py:402, in einsum(*args)
    397     return einsum(equation, *_operands)
    399 if len(operands) <= 2 or not opt_einsum.enabled:
    400     # the path for contracting 0 or 1 time(s) is already optimized
    401     # or the user has disabled using opt_einsum
--> 402     return _VF.einsum(equation, operands)  # type: ignore[attr-defined]
    404 path = None
    405 if opt_einsum.is_available():

RuntimeError: [enforce fail at alloc_cpu.cpp:114] data. DefaultCPUAllocator: not enough memory: you tried to allocate 54629649408 bytes.
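
For intuition (a rough back-of-the-envelope estimate, not taken from the thread): the failing einsum materializes attention logits of shape (batch, q, k, heads), so the allocation grows quadratically with the number of training rows. The head count below is an assumption used only for illustration.

# Rough estimate, assuming float32 logits and illustrative shapes only.
n_train = 12000          # training rows (the size mentioned further below)
n_heads = 6              # assumed head count, for illustration
bytes_per_element = 4    # float32

# Logits tensor of shape (batch, q, k, heads); the q * k term dominates.
logits_bytes = n_train * n_train * n_heads * bytes_per_element
print(f"~{logits_bytes / 1e9:.1f} GB per batch element")  # ~3.5 GB, before batching over features/ensembles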

@eddiebergman
Collaborator

eddiebergman commented Jan 9, 2025

Hiyo, sorry to hear it didn't help. You could try the memory_saving_mode= parameter, which can take a float | int specifying the amount of memory available in GB.

            memory_saving_mode:
                Enable GPU/CPU memory saving mode. This can help to prevent
                out-of-memory errors that result from computations that would consume
                more memory than available on the current device. We save memory by
                automatically batching certain model computations within TabPFN to
                reduce the total required memory. The options are:

                - If `bool`, enable/disable memory saving mode.
                - If `"auto"`, we will estimate the amount of memory required for the
                  forward pass and apply memory saving if it is more than the
                  available GPU/CPU memory. This is the recommended setting as it
                  allows for speed-ups and prevents memory errors depending on
                  the input data.
                - If `float` or `int`, we treat this value as the maximum amount of
                  available GPU/CPU memory (in GB). We will estimate the amount
                  of memory required for the forward pass and apply memory saving
                  if it is more than this value. Passing a float or int value for
                  this parameter is the same as setting it to True and explicitly
                  specifying the maximum free available memory

                !!! warning
                    This does not batch the original input data. We still recommend to
                    batch this as necessary if you run into memory errors! For example,
                    if the entire input data does not fit into memory, even the memory
                    save mode will not prevent memory errors.
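
As a concrete sketch of this suggestion (the X_train / y_train / X_test placeholders are reused from the snippets above, and 8 GB is just an illustrative budget):

from tabpfn import TabPFNRegressor

# Treat 8 GB as the available memory budget; TabPFN estimates the forward-pass
# memory and batches its internal computations if the estimate exceeds this value.
model = TabPFNRegressor(memory_saving_mode=8)
model.fit(X_train, y_train)
y_pred = model.predict(X_test)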


Out of curiosity, what is the amount of RAM available on your system?

I hope you don't mind that I changed the title of the issue to better reflect the current problem!

@eddiebergman changed the title from "[Windows] Running model.predict() on TabPFNRegressor gives error" to "[Windows] Memory error with model.predict() on TabPFNRegressor" Jan 9, 2025
@realshaktigupta
Author

I am already using the memory_saving_mode= parameter in the latest code where I got the memory-related error.
This is the latest model initialization I tried:

model = TabPFNRegressor(categorical_features_indices=[0], fit_mode="fit_preprocessors", memory_saving_mode=4, n_jobs=-1, ignore_pretraining_limits=True)

I also tried memory_saving_mode=8 and memory_saving_mode=12, but neither helped.

I have 16 GB of RAM on my system, and I used 12,000 rows and 68 columns of data to train the model.

@eddiebergman
Collaborator

@noahho or @LennartPurucker, I believe this is on the larger side of what's possible; correct me if I'm wrong. I also believe you will run into very long inference times trying to do this on a CPU.

https://priorlabs.ai/getting_started/intended_use/#hyperparameter-tuning

@davitens

davitens commented Jan 9, 2025

I have 16 GB of RAM on my system but am working with 4000 rows and 20 columns, and I'm seeing the same insanely long run times on CPU. memory_saving_mode returns:
RuntimeError: [enforce fail at alloc_cpu.cpp:114] data. DefaultCPUAllocator: not enough memory: you tried to allocate 24387060000 bytes.

@noahho
Collaborator

noahho commented Jan 9, 2025

Running on CPU will be very slow with larger loads, and I would not recommend it. @realshaktigupta, for the dataset size you are looking at, a CPU would take too long and wouldn't be suitable. If using our web API is an option, I would recommend that in the meantime (https://github.com/PriorLabs/tabpfn-client), or using our models via a Colab notebook with GPU support.

The issue running on CPU is twofold: first, the hardware isn't made for the massively parallel compute load of transformer-based models; second, many optimizations for calculating attention can't be used (e.g. FlashAttention, fp16 precision).

@eddiebergman @LennartPurucker I'm wondering if we should print a warning for larger datasets when run on CPU?

@noahho removed the bug label Jan 9, 2025
@LennartPurucker added the documentation label Jan 9, 2025
@LennartPurucker
Collaborator

A warning that points to Colab and the client would be a good idea.

@davitens @realshaktigupta You could try to use forced_inference_dtype_=torch.float32 (or lower), see https://priorlabs.ai/reference/tabpfn/classifier/#tabpfn.classifier.TabPFNClassifier.forced_inference_dtype_, and set n_jobs=1. I suspect that autocast and the dtype are the problem here, since a 4000 x 20 matrix should not use 24 GB of CPU RAM.

Generally, our options are quickly exhausted when using Torch on CPU by default, especially if the CPU does not support autocast.
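
A minimal sketch of this workaround. It assumes forced_inference_dtype_ can be assigned directly on the estimator before fitting, as the attribute name in the linked reference suggests; the exact mechanism may differ, so check the linked documentation.

import torch
from tabpfn import TabPFNRegressor

model = TabPFNRegressor(
    memory_saving_mode=4,  # treat 4 GB as the available budget
    n_jobs=1,              # single worker, as suggested above
)
# Assumption: the forced dtype is set as an attribute on the estimator;
# float32 (or bfloat16, as reported in the next comment) shrinks the tensors.
model.forced_inference_dtype_ = torch.float32
model.fit(X_train, y_train)
y_pred = model.predict(X_test)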

@realshaktigupta
Author

realshaktigupta commented Jan 11, 2025

Yes, it was a dtype-related issue. I managed to run it by setting forced_inference_dtype_=torch.bfloat16 and reducing the number of rows to 2500. Thanks.
