Skip to content

Commit

Permalink
fix autoresume with slashed directory
Browse files Browse the repository at this point in the history
  • Loading branch information
rishab-partha committed Jun 6, 2023
1 parent 2f864f6 commit 3dfb5f5
Show file tree
Hide file tree
Showing 8 changed files with 25 additions and 3 deletions.
3 changes: 2 additions & 1 deletion composer/loggers/remote_uploader_downloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -438,7 +438,8 @@ def download_file(
destination=destination,
object_store=self.remote_backend,
overwrite=overwrite,
progress_bar=progress_bar)
progress_bar=progress_bar,
marker=self.remote_backend.__class__.__name__ + 'bar')

def fit_end(self, state: State, logger: Logger):
self.wait_for_workers()
Expand Down
1 change: 1 addition & 0 deletions composer/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1414,6 +1414,7 @@ def _try_checkpoint_download(self, latest_checkpoint_path: str, save_latest_remo
object_store=logger,
overwrite=True,
progress_bar=load_progress_bar,
marker=logger.__class__.__name__,
)
break
except (NotImplementedError, FileNotFoundError):
Expand Down
3 changes: 2 additions & 1 deletion composer/utils/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,7 +294,8 @@ def download_checkpoint(path: str,
get_file(destination=rank_zero_checkpoint_filepath,
path=path,
object_store=object_store,
progress_bar=progress_bar)
progress_bar=progress_bar,
marker=object_store.__class__.__name__ + 'foo')
if extracted_checkpoint_folder is not None:
try:
with tarfile.open(rank_zero_checkpoint_filepath) as tarball:
Expand Down
3 changes: 3 additions & 0 deletions composer/utils/file_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -421,6 +421,7 @@ def get_file(
object_store: Optional[Union[ObjectStore, LoggerDestination]] = None,
overwrite: bool = False,
progress_bar: bool = True,
marker: Optional[str] = '',
):
"""Get a file from a local folder, URL, or object store.
Expand Down Expand Up @@ -487,6 +488,7 @@ def get_file(
object_store=object_store,
overwrite=overwrite,
progress_bar=progress_bar,
marker=marker,
)

try:
Expand All @@ -507,6 +509,7 @@ def get_file(
object_store=object_store,
overwrite=overwrite,
progress_bar=progress_bar,
marker=marker,
)
except FileNotFoundError as ee:
# Raise the original not found error first, which contains the path to the user-specified file
Expand Down
4 changes: 4 additions & 0 deletions composer/utils/object_store/libcloud_object_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,9 @@ def download_object(
# If the file already exits, short-circuit and skip the download
raise FileExistsError(f'filename {filename} exists and overwrite was set to False.')

dirname = os.path.dirname(filename)
if dirname:
os.makedirs(dirname, exist_ok=True)
obj = self._get_object(object_name)
# Download first to a tempfile, and then rename, in case if the file gets corrupted in transit
tmp_filepath = str(filename) + f'.{uuid.uuid4()}.tmp'
Expand All @@ -179,6 +182,7 @@ def download_object(
for chunk in iterate_with_callback(stream, obj.size, callback):
f.write(chunk)
except Exception as e:
print('yup')
# The download failed for some reason. Make a best-effort attempt to remove the temporary file.
try:
os.remove(tmp_filepath)
Expand Down
4 changes: 4 additions & 0 deletions composer/utils/object_store/oci_object_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,10 @@ def download_object(
del callback
if os.path.exists(filename) and not overwrite:
raise FileExistsError(f'The file at {filename} already exists and overwrite is set to False')

dirname = os.path.dirname(filename)
if dirname:
os.makedirs(dirname, exist_ok=True)
tmp_path = str(filename) + f'.{uuid.uuid4()}.tmp'

try:
Expand Down
5 changes: 5 additions & 0 deletions composer/utils/object_store/s3_object_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,12 @@ def download_object(
):
if os.path.exists(filename) and not overwrite:
raise FileExistsError(f'The file at {filename} already exists and overwrite is set to False.')

dirname = os.path.dirname(filename)
if dirname:
os.makedirs(dirname, exist_ok=True)
tmp_path = str(filename) + f'.{uuid.uuid4()}.tmp'

if callback is None:
cb_wrapper = None
else:
Expand Down
5 changes: 4 additions & 1 deletion tests/trainer/test_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -395,7 +395,7 @@ def _metrics_equal(self, train_metrics_1, train_metrics_2, eval_metrics_1, eval_
except AssertionError:
return False

def get_trainer(self, model=None, max_duration='2ep', **kwargs):
def get_trainer(self, model=None, max_duration='2ep', latest_filename='latest-rank{rank}.pt', **kwargs):
if model is None:
model = SimpleConvModel()
optimizer = torch.optim.Adam(model.parameters())
Expand All @@ -422,6 +422,7 @@ def get_trainer(self, model=None, max_duration='2ep', **kwargs):
eval_subset_num_batches=1,
save_interval='1ep',
eval_interval='1ep',
save_latest_filename=latest_filename,
save_filename='ep{epoch}.pt',
max_duration=max_duration,
optimizers=optimizer,
Expand Down Expand Up @@ -663,6 +664,7 @@ def test_autoresume(self, device: str, tmp_path: pathlib.Path, use_object_store:
pytest.importorskip('libcloud')

trainer_1 = self.get_trainer(
latest_filename='testdir/latest-rank{rank}.pt',
save_folder='first',
device=device,
run_name='big-chungus',
Expand All @@ -678,6 +680,7 @@ def test_autoresume(self, device: str, tmp_path: pathlib.Path, use_object_store:
shutil.rmtree('first')

trainer_2 = self.get_trainer(
latest_filename='testdir/latest-rank{rank}.pt',
save_folder='first',
device=device,
run_name='big-chungus',
Expand Down

0 comments on commit 3dfb5f5

Please sign in to comment.