diff --git a/adapta/ml/mlflow/_client.py b/adapta/ml/mlflow/_client.py index 1d3592b4..86083f53 100644 --- a/adapta/ml/mlflow/_client.py +++ b/adapta/ml/mlflow/_client.py @@ -151,7 +151,7 @@ def load_model_by_uri(model_uri: str) -> PyFuncModel: - ``/Users/me/path/to/local/model`` - ``relative/path/to/local/model`` - - ``s3://my_bucket/path/to/model`` + - ``s3a://my_bucket/path/to/model`` - ``runs:/<mlflow_run_id>/run-relative/path/to/model`` - ``models:/<model_name>/<model_version>`` - ``models:/<model_name>/<stage>`` diff --git a/adapta/storage/models/aws.py b/adapta/storage/models/aws.py index 3c1695fb..d6264db7 100644 --- a/adapta/storage/models/aws.py +++ b/adapta/storage/models/aws.py @@ -16,6 +16,8 @@ # limitations under the License. # +import re + from dataclasses import dataclass from urllib.parse import urlparse @@ -33,9 +35,6 @@ def to_uri(self) -> str: Converts the S3Path to a URI. :return: URI path """ - if not self.bucket or not self.path: - raise ValueError("Bucket and path must be defined") - return f"s3://{self.bucket}/{self.path}" def base_uri(self) -> str: @@ -43,9 +42,6 @@ def base_uri(self) -> str: Returns the base URI of the S3Path. :return: URI path """ - if not self.bucket: - raise ValueError("Bucket must be defined") - return f"https://{self.bucket}.s3.amazonaws.com" @classmethod @@ -62,6 +58,15 @@ def from_uri(cls, url: str) -> "S3Path": path: str protocol: str = DataProtocols.S3.value + def __post_init__(self): + if not self.bucket: + raise ValueError("Bucket must be defined") + + path_regex = r"//" + + if re.search(path_regex, self.path): + raise ValueError("Invalid S3Path provided: path should not contain consecutive slashes (//)") + @classmethod def from_hdfs_path(cls, hdfs_path: str) -> "S3Path": """ @@ -78,9 +83,6 @@ def to_hdfs_path(self) -> str: Converts the S3Path to an HDFS compatible path. 
:return: HDFS path """ - if not self.bucket or not self.path: - raise ValueError("Bucket and path must be defined") - return f"s3a://{self.bucket}/{self.path}" def to_delta_rs_path(self) -> str: diff --git a/tests/test_s3_storage_client.py b/tests/test_s3_storage_client.py index 62925be8..2ff4a55b 100644 --- a/tests/test_s3_storage_client.py +++ b/tests/test_s3_storage_client.py @@ -12,10 +12,43 @@ # See the License for the specific language governing permissions and # limitations under the License. # +import pytest from adapta.storage.blob.s3_storage_client import S3StorageClient from adapta.storage.models.aws import S3Path -from unittest.mock import patch, MagicMock +from unittest.mock import patch + + +def test_valid_s3_datapath(): + valid_s3_datapaths = [ + lambda: S3Path(bucket="bucket", path=""), + lambda: S3Path(bucket="bucket", path="path"), + lambda: S3Path(bucket="bucket", path="path/"), + lambda: S3Path(bucket="bucket", path="path/path_segment"), + lambda: S3Path(bucket="bucket", path="path/path_segment/path_segment"), + ] + + for s3_data_path in valid_s3_datapaths: + try: + s3_data_path() + except Exception as e: + pytest.fail(f"S3Path creation raised an exception: {e}") + + +def test_invalid_s3_datapath(): + malformed_s3_datapaths = [ + lambda: S3Path(bucket="bucket", path="path//path_segment"), + lambda: S3Path(bucket="bucket", path="path/path_segment//path_segment"), + ] + + for s3_data_path in malformed_s3_datapaths: + with pytest.raises(ValueError, match=r"Invalid S3Path provided: .*"): + s3_data_path() + + +def test_base_uri(): + path = S3Path(bucket="bucket", path="nested/key") + assert path.base_uri() == "https://bucket.s3.amazonaws.com" def test_from_hdfs_path(): @@ -24,6 +57,20 @@ def test_from_hdfs_path(): assert path.path == "nested/key" +def test_to_uri(): + bucket_name = "bucket" + path = "nested/key" + path_instance = S3Path(bucket=bucket_name, path=path) + assert path_instance.to_uri() == f"s3://{bucket_name}/{path}" + + +def 
test_to_delta_rs_path(): + bucket_name = "bucket" + path = "nested/key" + path_instance = S3Path(bucket=bucket_name, path=path) + assert path_instance.to_delta_rs_path() == f"s3a://bucket/nested/key" + + def test_to_hdfs_path(): path = S3Path("bucket", "nested/key").to_hdfs_path() assert path == "s3a://bucket/nested/key"