forked from mlflow/mlflow
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add UC-OSS integration test workflow (mlflow#13804)
- Loading branch information
1 parent
1dad595
commit 874176e
Showing 2 changed files with 216 additions and 0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,65 @@ | ||
# Integration test workflow: builds the Unity Catalog OSS server from source and
# runs the MLflow <-> UC OSS model-registry integration test against it.
name: uc-oss

on:
  # Manual runs. The job-level `if:` and the checkout step below already refer to
  # `workflow_dispatch` and `github.event.inputs.*`, so the trigger and its
  # inputs must be declared for that path to be reachable.
  workflow_dispatch:
    inputs:
      repository:
        description: >
          [Optional] Repository name with owner, e.g. mlflow/mlflow.
          Defaults to the repository that triggered the workflow.
        required: false
        default: ""
      ref:
        description: >
          [Optional] Branch, tag, or SHA to check out.
          Defaults to the ref that triggered the workflow.
        required: false
        default: ""
  pull_request: # for testing
    types:
      - opened
      - synchronize
      - reopened
      - ready_for_review
    paths:
      - .github/workflows/uc-oss.yml
      - mlflow/protos/**
      - mlflow/store/**
      # Changes to the integration test itself should also exercise this workflow.
      - tests/uc_oss/**
  schedule:
    # Run this workflow daily at 13:00 UTC
    - cron: "0 13 * * *"

# Only one run per workflow / event type / ref; a newer run cancels the older one.
concurrency:
  group: ${{ github.workflow }}-${{ github.event_name }}-${{ github.ref }}
  cancel-in-progress: true

defaults:
  run:
    shell: bash --noprofile --norc -exo pipefail {0}

env:
  MLFLOW_HOME: /home/runner/work/mlflow/mlflow

jobs:
  uc-oss-integration-test:
    runs-on: ubuntu-latest
    timeout-minutes: 30
    permissions: {}
    # Run for manual dispatches, for scheduled runs in the automation fork only,
    # and for pull requests that are not drafts.
    if: github.event_name == 'workflow_dispatch' || (github.event_name == 'schedule' && github.repository == 'mlflow-automation/mlflow') || (github.event_name == 'pull_request' && github.event.pull_request.draft == false)
    steps:
      - uses: actions/checkout@v4
        with:
          # Scheduled runs test upstream mlflow/mlflow. For other events the
          # dispatch inputs are used; when they are empty (e.g. pull requests)
          # the expressions evaluate to an empty string.
          repository: ${{ github.event_name == 'schedule' && 'mlflow/mlflow' || github.event.inputs.repository }}
          ref: ${{ github.event.inputs.ref }}
          submodules: recursive
      - uses: ./.github/actions/setup-python

      - name: Install dependencies
        run: |
          source ./dev/install-common-deps.sh --ml
      - name: Set up Java 17
        uses: actions/setup-java@v3
        with:
          java-version: "17"
          distribution: "temurin" # Use Temurin distribution of OpenJDK

      - name: Clone UnityCatalog at tag v0.2.1
        run: |
          git clone --branch v0.2.1 --depth 1 https://github.com/unitycatalog/unitycatalog.git
      - name: Build uc-oss server
        working-directory: unitycatalog
        run: |
          build/sbt package
      - name: Run tests for UnityCatalog
        run: |
          # The test module is skipped unless this variable is present.
          export UC_OSS_INTEGRATION=true
          pytest tests/uc_oss/test_uc_oss_integration.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,151 @@ | ||
import os | ||
import subprocess | ||
import sys | ||
|
||
import pandas as pd | ||
import pytest | ||
from sklearn import datasets | ||
from sklearn.ensemble import RandomForestClassifier | ||
from sklearn.model_selection import train_test_split | ||
|
||
import mlflow | ||
from mlflow.exceptions import MlflowException | ||
|
||
from tests.helper_functions import get_safe_port | ||
from tests.tracking.integration_test_utils import _await_server_up_or_die | ||
|
||
# Module-wide marker: every test here is skipped unless the environment opts in
# by defining UC_OSS_INTEGRATION (any value counts; only presence is checked).
_uc_oss_integration_disabled = os.environ.get("UC_OSS_INTEGRATION") is None
pytestmark = pytest.mark.skipif(
    _uc_oss_integration_disabled,
    reason="This test is only valid w/in the github workflow integration job",
)
|
||
|
||
def _stop_process(proc, timeout=30):
    """Terminate ``proc`` and wait for it to exit, killing it if it does not stop in time."""
    proc.terminate()
    try:
        proc.wait(timeout=timeout)
    except subprocess.TimeoutExpired:
        # The process ignored the termination request; force it down so the
        # enclosing ``Popen`` context manager does not block forever on exit.
        proc.kill()


@pytest.fixture(scope="module")
def setup_servers():
    """Start a UC OSS server and an MLflow tracking server for this module's tests.

    Launches ``bin/start-uc-server`` from the ``unitycatalog`` directory and an
    ``mlflow server`` on a free port, waits for both to accept connections, and
    points the MLflow tracking URI at the MLflow server and the registry URI at
    the UC OSS server.

    Yields:
        The URL of the MLflow tracking server.

    Both servers are stopped when the module's tests finish or setup fails.
    """
    # Port the UC OSS server is expected to listen on (assumed to be the start
    # script's default -- confirm if the script or its configuration changes).
    uc_port = 8080
    port = get_safe_port()

    # The two processes are nested, each with its own ``finally``, because
    # ``Popen.__exit__`` waits for the child to exit: if launching the MLflow
    # server fails, the UC server must already be told to stop or the wait hangs.
    with subprocess.Popen(["bin/start-uc-server"], cwd="unitycatalog") as uc_proc:
        try:
            with subprocess.Popen(
                [sys.executable, "-m", "mlflow", "server", "--port", str(port)]
            ) as mlflow_proc:
                try:
                    _await_server_up_or_die(port)
                    _await_server_up_or_die(uc_port)

                    mlflow_tracking_url = f"http://127.0.0.1:{port}"
                    uc_oss_url = f"uc:http://127.0.0.1:{uc_port}"

                    mlflow.set_tracking_uri(mlflow_tracking_url)
                    mlflow.set_registry_uri(uc_oss_url)

                    yield mlflow_tracking_url
                finally:
                    _stop_process(mlflow_proc)
        finally:
            _stop_process(uc_proc)
|
||
|
||
def test_integration(setup_servers, tmp_path):
    """Exercise the MLflow model registry end to end against a UC OSS server.

    Trains and registers an sklearn model under ``unity.default.iris``, loads it
    back for inference, lists and downloads its artifacts, then checks the
    get / update / search / delete calls for the registered model and its
    single version.

    Args:
        setup_servers: Fixture that starts the servers and configures the
            tracking and registry URIs.
        tmp_path: pytest-provided temporary directory used as download target.
    """
    catalog = "unity"
    schema = "default"
    registered_model_name = "iris"
    model_name = f"{catalog}.{schema}.{registered_model_name}"
    mlflow.set_experiment("iris-uc-oss")
    client = mlflow.MlflowClient()

    # The model must not exist before the test registers it.
    with pytest.raises(MlflowException, match="NOT_FOUND"):
        client.get_registered_model(model_name)

    X, y = datasets.load_iris(return_X_y=True, as_frame=True)
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

    with mlflow.start_run():
        # Train a sklearn model on the iris dataset
        clf = RandomForestClassifier(max_depth=7)
        clf.fit(X_train, y_train)
        # Take the first row of the training dataset as the model input example.
        input_example = X.iloc[[0]]
        # Log the model and register it as a new version in UC.
        mlflow.sklearn.log_model(
            clf,
            "model",
            # The signature is automatically inferred from the input example and
            # its predicted output.
            input_example=input_example,
            registered_model_name=model_name,
        )

    model_version = 1
    model_uri = f"models:/{model_name}/{model_version}"
    rm_desc = "UC-OSS/MLflow Iris model"
    mv_desc = "Version 1 of the UC-OSS/MLflow Iris model"

    # Load the model and do some batch inference.
    # By specifying the UC OSS model uri, mlflow will make UC OSS
    # REST API calls to retrieve the model
    loaded_model = mlflow.pyfunc.load_model(model_uri)
    predictions = loaded_model.predict(X_test)
    # One prediction per test row; a slice-is-not-None check could never fail.
    assert len(predictions) == len(X_test)
    iris_feature_names = datasets.load_iris().feature_names
    result = pd.DataFrame(X_test, columns=iris_feature_names)
    result["actual_class"] = y_test
    result["predicted_class"] = predictions
    assert len(result) == len(X_test)
    assert result["predicted_class"].notna().all()

    # list_artifacts will use the UC OSS model URI and make REST API calls to
    # UC OSS to:
    #   1) retrieve credentials (none for file based UC OSS)
    #   2) use the storage location returned from UC OSS for the model version
    #      list the artifacts stored in the location
    # Smoke check only: the call must complete without raising.
    mlflow.artifacts.list_artifacts(model_uri)

    path = os.path.join(tmp_path, "models", model_name, str(model_version))

    # download_artifacts will use the UC OSS model URI and make REST API calls
    # to UC OSS to:
    #   1) retrieve credentials (none for file based UC OSS)
    #   2) copy the artifact files from the storage location to the
    #      destination path
    mlflow.artifacts.download_artifacts(
        artifact_uri=model_uri,
        dst_path=path,
    )
    requirements_path = os.path.join(path, "requirements.txt")
    assert os.path.exists(requirements_path), f"File {requirements_path} does not exist."
    with open(requirements_path) as file:
        assert file.readlines(), f"File {requirements_path} is empty."

    # Test get RM/MV works
    model1 = client.get_registered_model(model_name)
    assert model1.name == model_name
    assert model1.description == ""
    model_v1 = client.get_model_version(name=model_name, version=model_version)
    assert model_v1.name == model_name
    assert model_v1.version == model_version
    assert model_v1.description == ""

    # Test update RM/MV works
    client.update_registered_model(model_name, description=rm_desc)
    model2 = client.get_registered_model(model_name)
    assert model2.name == model_name
    assert model2.description == rm_desc
    client.update_model_version(name=model_name, version=model_version, description=mv_desc)
    model_v1_2 = client.get_model_version(name=model_name, version=model_version)
    assert model_v1_2.name == model_name
    assert model_v1_2.version == model_version
    assert model_v1_2.description == mv_desc

    # Test search and delete RM/MV works
    rms = client.search_registered_models()
    assert len(rms) == 1
    mvs = client.search_model_versions(f"name='{model_name}'")
    assert len(mvs) == 1
    client.delete_model_version(name=model_name, version=model_version)
    mvs = client.search_model_versions(f"name='{model_name}'")
    assert len(mvs) == 0
    client.delete_registered_model(name=model_name)
    rms = client.search_registered_models()
    assert len(rms) == 0
    with pytest.raises(MlflowException, match="NOT_FOUND"):
        client.get_registered_model(model_name)