Skip to content

Commit

Permalink
Port more functions from MLFlow and other minor improvements (#501)
Browse files Browse the repository at this point in the history
  • Loading branch information
matt035343 authored Jan 17, 2025
1 parent 90a4a0e commit 9c7eb84
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 9 deletions.
51 changes: 50 additions & 1 deletion adapta/ml/mlflow/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def get_model_version_by_alias(self, model_name: str, alias: str) -> ModelVersio
"""
return self._client.get_model_version_by_alias(model_name, alias)

def _get_artifact_repo_backported(self, run_id) -> mlflow.store.artifact_repo.ArtifactRepository:
def _get_artifact_repo_backported(self, run_id: str) -> mlflow.store.artifact_repo.ArtifactRepository:
run = self._client.get_run(run_id)

artifact_uri = (
Expand Down Expand Up @@ -137,6 +137,55 @@ def set_model_alias(self, model_name: str, alias: str, model_version: Optional[s
version=model_version,
)

def log_dict(self, artifact: dict, artifact_path: str, run_id: str):
"""
inherited the logging dictionary in Mlflow
:param artifact: dictionary to log
:param artifact_path: artifact path
:param run_id: run id
"""
self._client.log_dict(run_id=run_id, dictionary=artifact, artifact_file=artifact_path)

def log_metric(self, run_id: str, metric_name: str, metric_value: float):
"""
inherited the logging metric in Mlflow
:param run_id: run id
:param metric_name: metric name
:param metric_value: metric value
"""
self._client.log_metric(run_id=run_id, key=metric_name, value=metric_value)

def create_run(self, experiment_name: str, run_name: str) -> str:
"""
inherited the creating run in Mlflow
:param experiment_name: experiment name
:param run_name: run name
:return: run id
"""
experiment = self._client.get_experiment_by_name(experiment_name)
return self._client.create_run(experiment_id=experiment.experiment_id, run_name=run_name).info.run_id

def terminate_run(self, run_id: str):
"""
inherited the stopping run in Mlflow
:param run_id: run id
"""
self._client.set_terminated(run_id)

def set_run_tag(self, key: str, value: any, run_id: str):
"""
inherited the setting run tag in Mlflow
:param key: tag key
:param value: tag value
:param run_id: run id
"""
self._client.set_tag(run_id=run_id, key=key, value=value)

@staticmethod
def load_model_by_name(model_name: str, stage_or_version: str) -> PyFuncModel:
"""
Expand Down
25 changes: 17 additions & 8 deletions adapta/ml/mlflow/_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import importlib
import pathlib
import tempfile
from typing import Optional, Dict, Any
from typing import Optional, Dict, Any, Literal

import mlflow
from mlflow.pyfunc import PythonModel
Expand Down Expand Up @@ -51,27 +51,34 @@ def register_mlflow_model(
mlflow_client: MlflowBasicClient,
model_name: str,
experiment: str,
run_name: str = None,
transition_to_stage: str = None,
version_alias: str = None,
run_name: Optional[str] = None,
run_id: Optional[str] = None,
transition_to_stage: Optional[Literal["staging", "production"]] = None,
parent_run_id: Optional[str] = None,
version_alias: Optional[str] = None,
metrics: Optional[Dict[str, float]] = None,
model_params: Optional[Dict[str, Any]] = None,
artifacts_to_log: Dict[str, str] = None,
):
) -> str:
"""Registers mlflow model
:param model: Machine learning model to register
:param mlflow_client: Mlflow client
:param model_name: Name of Mlflow model
:param experiment: Name of Mlflow experiment
:param run_name: Name of Mlflow run
:param run_name: Name of Mlflow run (only used if run_id is None)
:param run_id: Run id
:param parent_run_id: Parent run id
:param transition_to_stage: Whether to transition to stage
:param version_alias: Alias to assign to model
:param metrics: Metrics to log
:param model_params: Model hyperparameters to log
:param artifacts_to_log: Additional artifacts to log
:return: Run id of the newly created run for registering the model.
If run_id is provided, it will be the same as run_id
"""
assert transition_to_stage in [None, "Staging", "Production"]
assert transition_to_stage in [None, "staging", "production"]

mlflow.set_experiment(experiment)

Expand All @@ -98,7 +105,7 @@ def register_mlflow_model(
raise ValueError('Artifact names "model" and "config" are reserved for internal usage')
artifacts.update(artifacts_to_log)

with mlflow.start_run(nested=True, run_name=run_name):
with mlflow.start_run(nested=True, run_name=run_name, run_id=run_id, parent_run_id=parent_run_id) as run:
mlflow.pyfunc.log_model(
artifact_path="mlflow_model",
python_model=_MlflowMachineLearningModel(),
Expand All @@ -123,3 +130,5 @@ def register_mlflow_model(
stage=transition_to_stage,
model_version=version,
)

return run.info.run_id

0 comments on commit 9c7eb84

Please sign in to comment.