diff --git a/composer/loggers/mlflow_logger.py b/composer/loggers/mlflow_logger.py index 03070c28f9..aed32eea39 100644 --- a/composer/loggers/mlflow_logger.py +++ b/composer/loggers/mlflow_logger.py @@ -150,7 +150,12 @@ def __init__( ) assert self.experiment_name is not None # type hint - if os.getenv('DATABRICKS_TOKEN') is not None and not self.experiment_name.startswith('/Users/'): + if os.getenv( + 'DATABRICKS_TOKEN', + ) is not None and not self.experiment_name.startswith(( + '/Users/', + '/Shared/', + )): try: from databricks.sdk import WorkspaceClient except ImportError as e: @@ -160,7 +165,7 @@ def __init__( conda_channel='conda-forge', ) from e databricks_username = WorkspaceClient().current_user.me().user_name or '' - self.experiment_name = '/' + os.path.join('Users', databricks_username, self.experiment_name) + self.experiment_name = os.path.join('/Users', databricks_username, self.experiment_name.strip('/')) self._mlflow_client = MlflowClient(self.tracking_uri) # Set experiment