Skip to content

Commit

Permalink
Add refit_full logging (open-mmlab#2913)
Browse files Browse the repository at this point in the history
  • Loading branch information
Innixma authored Feb 14, 2023
1 parent 727d674 commit 8c70c0c
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 2 deletions.
6 changes: 4 additions & 2 deletions tabular/src/autogluon/tabular/models/knn/knn_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,14 +107,16 @@ def _estimate_memory_usage(self, X, **kwargs):
return expected_final_model_size_bytes

def _validate_fit_memory_usage(self, **kwargs):
    """Abort fitting before training if the expected model size risks exhausting memory.

    Estimates the final model size via ``self.estimate_memory_usage`` and compares it
    against currently available virtual memory. Emits a warning above a soft threshold
    and raises ``NotEnoughMemoryError`` above the hard threshold so training fails fast
    instead of OOM-killing the process.

    NOTE(review): the rendered diff in SOURCE contained both the pre- and post-commit
    threshold lines; this body is the reconstructed post-commit version (the old
    hard-coded ``0.20`` check replaced by ``max_memory_safety_proportion``).

    Raises
    ------
    NotEnoughMemoryError
        If the expected model memory footprint exceeds the safe proportion of
        available memory (scaled by the user-configurable ``max_memory_usage_ratio``).
    """
    # Hard limit: abort if the model would consume more than this proportion
    # of available memory (before scaling by max_memory_usage_ratio).
    max_memory_safety_proportion = 0.2
    # User-tunable multiplier on all memory thresholds (from AbstractModel params_aux).
    max_memory_usage_ratio = self.params_aux['max_memory_usage_ratio']
    expected_final_model_size_bytes = self.estimate_memory_usage(**kwargs)
    if expected_final_model_size_bytes > 10000000:  # Only worth checking if expected model size is >10MB
        available_mem = ResourceManager.get_available_virtual_mem()
        model_memory_ratio = expected_final_model_size_bytes / available_mem
        # Soft limit (0.15 of available memory, scaled): warn but continue training.
        if model_memory_ratio > (0.15 * max_memory_usage_ratio):
            logger.warning(f'\tWarning: Model is expected to require {round(model_memory_ratio * 100, 2)}% of available memory... '
                           f'({max_memory_safety_proportion*100}% is the max safe size.)')
        if model_memory_ratio > (max_memory_safety_proportion * max_memory_usage_ratio):
            raise NotEnoughMemoryError  # don't train full model to avoid OOM error

# TODO: Won't work for RAPIDS without modification
Expand Down
8 changes: 8 additions & 0 deletions tabular/src/autogluon/tabular/predictor/predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -899,6 +899,7 @@ def _post_fit(self, keep_only_best=False, refit_full=False, set_best_to_refit_fu
if infer_limit is not None:
infer_limit = infer_limit - self._learner.preprocess_1_time
trainer_model_best = self._trainer.get_model_best(infer_limit=infer_limit)
logger.log(20, 'Automatically performing refit_full as a post-fit operation (due to `.fit(..., refit_full=True)`')
self.refit_full(model=refit_full, set_best_to_refit_full=False)
if set_best_to_refit_full:
model_full_dict = self._trainer.get_model_full_dict()
Expand Down Expand Up @@ -2264,7 +2265,12 @@ def refit_full(self, model='all', set_best_to_refit_full=True):
Dictionary of original model names -> refit_full model names.
"""
self._assert_is_fit('refit_full')
ts = time.time()
model_best = self._get_model_best(can_infer=None)
logger.log(20, 'Refitting models via `predictor.refit_full` using all of the data (combined train and validation)...\n'
'\tModels trained in this way will have the suffix "_FULL" and have NaN validation score.\n'
'\tThis process is not bound by time_limit, but should take less time than the original `predictor.fit` call.\n'
'\tTo learn more, refer to the `.refit_full` method docstring which explains how "_FULL" models differ from normal models.')
refit_full_dict = self._learner.refit_ensemble_full(model=model)

if set_best_to_refit_full:
Expand All @@ -2289,6 +2295,8 @@ def refit_full(self, model='all', set_best_to_refit_full=True):
f'Best model ("{model_best}") is not present in refit_full dictionary. '
f'Training may have failed on the refit model. AutoGluon will default to using "{model_best}" for predict() and predict_proba().')

te = time.time()
logger.log(20, f'Refit complete, total runtime = {round(te - ts, 2)}s')
return refit_full_dict

def get_model_best(self):
Expand Down

0 comments on commit 8c70c0c

Please sign in to comment.