diff --git a/Makefile b/Makefile
index a09c2d43f..e5047c5bc 100644
--- a/Makefile
+++ b/Makefile
@@ -50,6 +50,9 @@ uml: ## Install documentation dependencies and generate UML diagrams
 	pyreverse pymc_marketing/mmm -d docs/source/uml -f 'ALL' -o png -p mmm
 	pyreverse pymc_marketing/clv -d docs/source/uml -f 'ALL' -o png -p clv
 
+mlflow_server: ## Start MLflow server on port 5000
+	mlflow server --backend-store-uri sqlite:///mlruns.db --default-artifact-root ./mlruns
+
 #################################################################################
 # Self Documenting Commands                                                     #
 #################################################################################
diff --git a/pymc_marketing/mlflow.py b/pymc_marketing/mlflow.py
index 60c7ffcce..da281a64e 100644
--- a/pymc_marketing/mlflow.py
+++ b/pymc_marketing/mlflow.py
@@ -304,6 +304,11 @@ def log_model_derived_info(model: Model) -> None:
     - The model representation (str).
     - The model coordinates (coords.json).
 
+    Parameters
+    ----------
+    model : Model
+        The PyMC model object.
+
     """
     log_types_of_parameters(model)
 
@@ -321,6 +326,7 @@ def log_model_derived_info(model: Model) -> None:
 
 def log_sample_diagnostics(
     idata: az.InferenceData,
+    tune: int | None = None,
 ) -> None:
     """Log sample diagnostics to MLflow.
 
@@ -336,6 +342,14 @@
     - The version of the inference library
     - The version of ArviZ
 
+    Parameters
+    ----------
+    idata : az.InferenceData
+        The InferenceData object returned by the sampling method.
+    tune : int, optional
+        The number of tuning steps used in sampling. Derived from the
+        inference data if not provided.
+
     """
     if "posterior" not in idata:
         raise KeyError("InferenceData object does not contain the group posterior.")
@@ -348,19 +362,28 @@ def log_sample_diagnostics(
 
     diverging = sample_stats["diverging"]
 
+    chains = posterior.sizes["chain"]
+    draws = posterior.sizes["draw"]
+    posterior_samples = chains * draws
+
+    tuning_step = sample_stats.attrs.get("tuning_steps", tune)
+    if tuning_step is not None:
+        tuning_samples = tuning_step * chains
+        mlflow.log_param("tuning_steps", tuning_step)
+        mlflow.log_param("tuning_samples", tuning_samples)
+
     total_divergences = diverging.sum().item()
     mlflow.log_metric("total_divergences", total_divergences)
 
     if sampling_time := sample_stats.attrs.get("sampling_time"):
         mlflow.log_metric("sampling_time", sampling_time)
         mlflow.log_metric(
             "time_per_draw",
-            sampling_time / (posterior.sizes["draw"] * posterior.sizes["chain"]),
+            sampling_time / posterior_samples,
         )
 
-    if tuning_step := sample_stats.attrs.get("tuning_steps"):
-        mlflow.log_param("tuning_steps", tuning_step)
-
-    mlflow.log_param("draws", posterior.sizes["draw"])
-    mlflow.log_param("chains", posterior.sizes["chain"])
+    mlflow.log_param("draws", draws)
+    mlflow.log_param("chains", chains)
+    mlflow.log_param("posterior_samples", posterior_samples)
 
     if inference_library := posterior.attrs.get("inference_library"):
         mlflow.log_param("inference_library", inference_library)
@@ -382,8 +405,7 @@ def log_inference_data(
     idata : az.InferenceData
         The InferenceData object returned by the sampling method.
     save_file : str | Path
-        The path to save the InferenceData object as a net
-        CDF file.
+        The path to save the InferenceData object as a netCDF file.
 
     """
     idata.to_netcdf(str(save_file))
@@ -516,8 +538,11 @@ def new_sample(*args, **kwargs):
             mlflow.log_param("pymc_version", pm.__version__)
             mlflow.log_param("nuts_sampler", kwargs.get("nuts_sampler", "pymc"))
 
+            # Align with the default values in pymc.sample
+            tune = kwargs.get("tune", 1000)
+
             if log_sampler_info:
-                log_sample_diagnostics(idata)
+                log_sample_diagnostics(idata, tune=tune)
                 log_arviz_summary(
                     idata,
                     "summary.html",
diff --git a/pyproject.toml b/pyproject.toml
index 171a32e30..1c452ccc6 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -62,6 +62,7 @@ docs = [
     "sphinx",
     "sphinxext-opengraph",
     "watermark",
+    "mlflow>=2.0.0",
 ]
 lint = ["mypy", "pandas-stubs", "pre-commit>=2.19.0", "ruff>=0.1.4"]
 test = [
diff --git a/tests/test_mlflow.py b/tests/test_mlflow.py
index a52975101..8e7d1a25b 100644
--- a/tests/test_mlflow.py
+++ b/tests/test_mlflow.py
@@ -231,10 +231,13 @@ def metric_checks(metrics, nuts_sampler) -> None:
 
 
 def param_checks(params, draws: int, chains: int, tune: int, nuts_sampler: str) -> None:
     assert params["draws"] == str(draws)
     assert params["chains"] == str(chains)
+    assert params["posterior_samples"] == str(draws * chains)
+
     if nuts_sampler not in ["numpyro", "blackjax"]:
         assert params["inference_library"] == nuts_sampler
-    if nuts_sampler not in ["numpyro", "nutpie", "blackjax"]:
-        assert params["tuning_steps"] == str(tune)
+
+    assert params["tuning_steps"] == str(tune)
+    assert params["tuning_samples"] == str(tune * chains)
     assert params["pymc_marketing_version"] == __version__