Skip to content

Commit

Permalink
fix get best cfg in sft and lora mode bug
Browse files Browse the repository at this point in the history
  • Loading branch information
Caozhou1995 committed Sep 18, 2023
1 parent fda5145 commit b1bc3eb
Showing 1 changed file with 12 additions and 2 deletions.
14 changes: 12 additions & 2 deletions python/paddle/distributed/auto_tuner/recorder.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,18 +54,28 @@ def get_best(self, metric, direction, mode=None) -> Tuple[dict, bool]:
return (self.history[0], True)
if mode == "SFT" or mode == "LoRA":
best_cfg = self.history[0]
if (
isinstance(best_cfg["max_mem_usage"], str)
or best_cfg["time"] == -1
):
return (best_cfg, True)
first_few = 1
for cfg in self.history:
if (
cfg["max_mem_usage"] < best_cfg["max_mem_usage"]
not isinstance(cfg["max_mem_usage"], str)
and cfg["max_mem_usage"] < best_cfg["max_mem_usage"]
and cfg["time"] != -1
):
best_cfg = cfg
first_few += 1
if first_few >= 5:
break
return (best_cfg, False)

if (
isinstance(self.history[0]["max_mem_usage"], str)
or self.history[0]["time"] == -1
):
return (self.history[0], False)
return (self.history[0], False)

def store_history(self, path="./history.csv"):
Expand Down

0 comments on commit b1bc3eb

Please sign in to comment.