Skip to content

Commit

Permalink
Minor cleanup of UC get_object_size (#2989)
Browse files Browse the repository at this point in the history
  • Loading branch information
dakinggg authored Feb 9, 2024
1 parent 2fd6dd6 commit 1a22691
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 3 deletions.
6 changes: 3 additions & 3 deletions composer/utils/object_store/uc_object_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,9 +211,9 @@ def get_object_size(self, object_name: str) -> int:
# Note: The UC team is working on changes to fix the files.get_status API, but it currently
# does not work. Once fixed, we will call the files API endpoint. We currently only use this
# function in Composer and LLM-foundry to check the UC object's existence.
self.client.api_client.do(method='HEAD',
path=f'{self._UC_VOLUME_FILES_API_ENDPOINT}/{self._get_object_path(object_name)}',
headers={'Source': 'mosaicml/composer'})
object_path = self._get_object_path(object_name).lstrip('/')
path = os.path.join(self._UC_VOLUME_FILES_API_ENDPOINT, object_path)
self.client.api_client.do(method='HEAD', path=path, headers={'Source': 'mosaicml/composer'})
return 1000000 # Dummy value, as we don't have a way to get the size of the file
except DatabricksError as e:
# If the code reaches here, the file was not found
Expand Down
8 changes: 8 additions & 0 deletions tests/utils/object_store/test_uc_object_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,14 @@ def test_get_object_size(ws_client, uc_object_store, result: str):
raise NotImplementedError(f'Test for result={result} is not implemented.')


def test_get_object_size_full_path(ws_client, uc_object_store):
ws_client.api_client.do.return_value = {}
assert uc_object_store.get_object_size('Volumes/catalog/schema/volume/train.txt') == 1000000
ws_client.api_client.do.assert_called_with(method='HEAD',
path=f'/api/2.0/fs/files/Volumes/catalog/schema/volume/train.txt',
headers={'Source': 'mosaicml/composer'})


def test_get_uri(uc_object_store):
assert uc_object_store.get_uri('train.txt') == 'dbfs:/Volumes/catalog/schema/volume/train.txt'
assert uc_object_store.get_uri('Volumes/catalog/schema/volume/checkpoint/model.bin'
Expand Down

0 comments on commit 1a22691

Please sign in to comment.