Add upload timeout patch to mlflow on azure (mosaicml#3265)
dakinggg authored May 10, 2024
1 parent e625a06 commit d895d56
Showing 1 changed file with 96 additions and 1 deletion.
97 changes: 96 additions & 1 deletion composer/utils/object_store/mlflow_object_store.py
@@ -72,10 +72,99 @@ def _wrap_mlflow_exceptions(uri: str, e: Exception):
    raise e


def _get_timeout_and_set_socket_default() -> Optional[int]:
    timeout = os.environ.get('MLFLOW_PATCHED_FILE_UPLOAD_TIMEOUT', None)
    if timeout is not None:
        import socket
        timeout = int(timeout)
        socket.setdefaulttimeout(timeout)
    return timeout
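
For reference, a quick sketch of how the helper above is driven. The value is read from the `MLFLOW_PATCHED_FILE_UPLOAD_TIMEOUT` environment variable and interpreted in seconds (an assumption consistent with `socket.setdefaulttimeout`, which takes seconds); when unset, the function returns `None` and leaves the socket default untouched:

import os
import socket

# Opt in to the patched upload timeout before any upload starts (seconds).
os.environ['MLFLOW_PATCHED_FILE_UPLOAD_TIMEOUT'] = '300'

timeout = _get_timeout_and_set_socket_default()
assert timeout == 300
assert socket.getdefaulttimeout() == 300  # every new socket now times out after 300s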


# Original source: https://github.com/mlflow/mlflow/blob/a85081631eb665fa25046cb0b7daf0fbbdd5949f/mlflow/azure/client.py#L42
def _patch_adls_file_upload_with_timeout(sas_url, local_file, start_byte, size, position, headers, is_single):
    """Performs an ADLS Azure file update `Patch` operation.

    (https://docs.microsoft.com/en-us/rest/api/storageservices/datalakestoragegen2/path/update)

    Args:
        sas_url: A shared access signature URL referring to the Azure ADLS server
            to which the file update command should be issued.
        local_file: The local file to upload.
        start_byte: The starting byte of the local file to upload.
        size: The number of bytes to upload.
        position: Positional offset of the data in the Patch request.
        headers: Additional headers to include in the Patch request body.
        is_single: Whether this is the only patch operation for this file.
    """
    from mlflow.azure.client import _append_query_parameters, _is_valid_adls_patch_header, _logger
    from mlflow.utils import rest_utils
    from mlflow.utils.file_utils import read_chunk

    new_params = {'action': 'append', 'position': str(position)}
    if is_single:
        new_params['flush'] = 'true'
    request_url = _append_query_parameters(sas_url, new_params)

    request_headers = {}
    for name, value in headers.items():
        if _is_valid_adls_patch_header(name):
            request_headers[name] = value
        else:
            _logger.debug("Removed unsupported '%s' header for ADLS Gen2 Patch operation", name)

    data = read_chunk(local_file, size, start_byte)

    ### Changed here to pass a timeout along to cloud_storage_http_request
    ### And to set the socket timeout
    timeout = _get_timeout_and_set_socket_default()
    with rest_utils.cloud_storage_http_request(
        'patch',
        request_url,
        data=data,
        headers=request_headers,
        timeout=timeout,
    ) as response:
        rest_utils.augmented_raise_for_status(response)
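
At the HTTP level, the change above amounts to something like the following plain `requests` call, using the names computed inside the function (a simplified sketch: `cloud_storage_http_request` also layers retries and session handling on top, which this omits):

import requests

# Hypothetical standalone equivalent of the patched call. With timeout=None
# (the pre-patch behavior) the client can block indefinitely on a stalled
# connection, which is the hang this commit works around.
response = requests.patch(request_url, data=data, headers=request_headers, timeout=timeout)
response.raise_for_status()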


def _put_adls_file_creation_with_timeout(sas_url, headers):
    """Performs an ADLS Azure file create `Put` operation.

    (https://docs.microsoft.com/en-us/rest/api/storageservices/datalakestoragegen2/path/create)

    Args:
        sas_url: A shared access signature URL referring to the Azure ADLS server
            to which the file creation command should be issued.
        headers: Additional headers to include in the Put request body.
    """
    from mlflow.azure.client import _append_query_parameters, _is_valid_adls_put_header, _logger
    from mlflow.utils import rest_utils

    request_url = _append_query_parameters(sas_url, {'resource': 'file'})

    request_headers = {}
    for name, value in headers.items():
        if _is_valid_adls_put_header(name):
            request_headers[name] = value
        else:
            _logger.debug("Removed unsupported '%s' header for ADLS Gen2 Put operation", name)

    ### Changed here to pass a timeout along to cloud_storage_http_request
    ### And to set the socket timeout
    timeout = _get_timeout_and_set_socket_default()
    with rest_utils.cloud_storage_http_request(
        'put',
        request_url,
        headers=request_headers,
        timeout=timeout,
    ) as response:
        rest_utils.augmented_raise_for_status(response)
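
Both helpers lean on mlflow's private `_append_query_parameters` to splice extra parameters (`action`, `position`, `flush`, `resource`) into the SAS URL without disturbing the signature's existing query string. A rough standard-library equivalent, for illustration only (not mlflow's actual implementation):

from urllib.parse import parse_qsl, urlencode, urlparse, urlunparse

def append_query_parameters(url, params):
    # Merge `params` into the URL's existing query string.
    parts = urlparse(url)
    query = dict(parse_qsl(parts.query))
    query.update({key: str(value) for key, value in params.items()})
    return urlunparse(parts._replace(query=urlencode(query)))

# append_query_parameters('https://acct.dfs.core.windows.net/fs/f?sv=abc', {'resource': 'file'})
# -> 'https://acct.dfs.core.windows.net/fs/f?sv=abc&resource=file'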


class MLFlowObjectStore(ObjectStore):
    """Utility class for uploading and downloading artifacts from MLflow.

-   It can be initializd for an existing run, a new run in an existing experiment, the active run used by the `mlflow`
+   It can be initialized for an existing run, a new run in an existing experiment, the active run used by the `mlflow`
    module, or a new run in a new experiment. See the documentation for ``path`` for more details.

    .. note::
@@ -127,6 +216,12 @@ def __init__(self, path: str, multipart_upload_chunk_size: int = 100 * 1024 * 10
        except ImportError as e:
            raise MissingConditionalImportError('databricks', conda_package='databricks-sdk>=0.15.0,<1.0') from e

        # This is a temporary workaround for an intermittent hang we have encountered when uploading files to ADLS.
        # MLflow is working on an upstream fix, but in the meantime, patching in timeouts works around the hang.
        log.debug('Patching MLflow Azure client to include timeout in ADLS file upload')
        mlflow.store.artifact.databricks_artifact_repo.patch_adls_file_upload = _patch_adls_file_upload_with_timeout  # type: ignore
        mlflow.store.artifact.databricks_artifact_repo.put_adls_file_creation = _put_adls_file_creation_with_timeout  # type: ignore

        tracking_uri = os.getenv(
            mlflow.environment_variables.MLFLOW_TRACKING_URI.name,  # pyright: ignore[reportGeneralTypeIssues]
            MLFLOW_DATABRICKS_TRACKING_URI,
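
A note on where the patch is applied: the assignments above rebind `patch_adls_file_upload` and `put_adls_file_creation` inside `mlflow.store.artifact.databricks_artifact_repo`, not in `mlflow.azure.client` where they are defined. Assuming the artifact repository pulled them in with `from mlflow.azure.client import ...` (the usual reason a patch must target the importing module), it holds its own references, so that module's namespace is the one that has to change. A self-contained sketch of the technique with hypothetical names:

import types

# Stand-ins for mlflow.azure.client and the module that imports from it.
client = types.ModuleType('client')
client.upload = lambda: 'no timeout'

repo = types.ModuleType('repo')
repo.upload = client.upload              # mimics `from client import upload`
repo.do_upload = lambda: repo.upload()   # the caller resolves `upload` via its own module

repo.upload = lambda: 'with timeout'     # rebind the importer's reference
assert repo.do_upload() == 'with timeout'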