diff --git a/.gitignore b/.gitignore index d6a66f4..f4f6dab 100644 --- a/.gitignore +++ b/.gitignore @@ -4,3 +4,4 @@ Manifest.toml /docs/build/ mlruns +coverage diff --git a/Project.toml b/Project.toml index 170fc46..1d37f9a 100644 --- a/Project.toml +++ b/Project.toml @@ -14,11 +14,11 @@ UUIDs = "cf7118a7-6976-5b1a-9a39-7adc72f591a4" [compat] FilePathsBase = "0.9" -HTTP = "0.9,1.2" +HTTP = "1.9" JSON = "0.21" ShowCases = "0.1" -URIs = "1" -julia = "1" +URIs = "1.0" +julia = "1.0" [extras] Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" diff --git a/docs/make.jl b/docs/make.jl index 0c0ecc7..00e20ad 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -11,16 +11,16 @@ makedocs(; format=Documenter.HTML(; prettyurls=get(ENV, "CI", "false") == "true", canonical="https://juliaai.github.io/MLFlowClient.jl", - assets=String[], + assets=String[] ), pages=[ "Home" => "index.md", "Tutorial" => "tutorial.md", "Reference" => "reference.md" - ], + ] ) deploydocs(; repo="github.com/JuliaAI/MLFlowClient.jl", - devbranch="main", + devbranch="main" ) diff --git a/docs/src/reference.md b/docs/src/reference.md index b31d1c8..6fc873c 100644 --- a/docs/src/reference.md +++ b/docs/src/reference.md @@ -26,8 +26,9 @@ MLFlowArtifactDirInfo createexperiment getexperiment getorcreateexperiment -listexperiments deleteexperiment +searchexperiments +listexperiments ``` # Runs @@ -50,5 +51,8 @@ listartifacts mlfget mlfpost uri +generatefilterfromentity_type generatefilterfromparams +generatefilterfromattributes + ``` diff --git a/src/MLFlowClient.jl b/src/MLFlowClient.jl index 92642ef..04346fd 100644 --- a/src/MLFlowClient.jl +++ b/src/MLFlowClient.jl @@ -20,19 +20,25 @@ using JSON using ShowCases using FilePathsBase: AbstractPath -include("types.jl") +include("types/core.jl") export MLFlow, - MLFlowExperiment, + MLFlowExperiment + +include("types/runs.jl") +export MLFlowRunStatus, MLFlowRunInfo, - get_run_id, - MLFlowRunData, - get_params, MLFlowRunDataMetric, + MLFlowRunData, MLFlowRun, get_info, 
get_data, + get_run_id, + get_params + +include("types/artifacts.jl") +export MLFlowArtifactFileInfo, MLFlowArtifactDirInfo, get_path, @@ -41,6 +47,8 @@ export include("utils.jl") export - generatefilterfromparams + generatefilterfromparams, + generatefilterfromattributes, + generatefilterfromentity_type include("experiments.jl") export @@ -48,7 +56,7 @@ export getexperiment, getorcreateexperiment, deleteexperiment, - listexperiments + searchexperiments include("runs.jl") export @@ -56,10 +64,12 @@ export getrun, updaterun, deleterun, - searchruns, + searchruns + +include("loggers.jl") +export logparam, logmetric, logartifact, listartifacts - end diff --git a/src/deprecated.jl b/src/deprecated.jl new file mode 100644 index 0000000..424c33a --- /dev/null +++ b/src/deprecated.jl @@ -0,0 +1,12 @@ +""" + listexperiments(mlf::MLFlow) + +Returns a list of MLFlow experiments. + +Deprecated (last MLFlow version: 1.30.1) in favor of [`searchexperiments`](@ref). +""" + +function listexperiments(mlf::MLFlow) +endpoint = "experiments/list" + mlfget(mlf, endpoint) +end diff --git a/src/experiments.jl b/src/experiments.jl index 0e10be9..650e79b 100644 --- a/src/experiments.jl +++ b/src/experiments.jl @@ -73,12 +73,12 @@ function getexperiment(mlf::MLFlow, experiment_name::String) end function _getexperimentbyid(mlf::MLFlow, experiment_id::Integer) endpoint = "experiments/get" - arguments = (:experiment_id => experiment_id, ) + arguments = (:experiment_id => experiment_id,) mlfget(mlf, endpoint; arguments...)["experiment"] end function _getexperimentbyname(mlf::MLFlow, experiment_name::String) endpoint = "experiments/get-by-name" - arguments = (:experiment_name => experiment_name, ) + arguments = (:experiment_name => experiment_name,) mlfget(mlf, endpoint; arguments...)["experiment"] end @@ -147,13 +147,59 @@ deleteexperiment(mlf::MLFlow, experiment::MLFlowExperiment) = deleteexperiment(mlf, experiment.experiment_id) """ - listexperiments(mlf::MLFlow) + searchexperiments(mlf::MLFlow) -Returns a list of MLFlow 
experiments. +Searches for experiments in an MLFlow instance. + +# Arguments +- `mlf`: [`MLFlow`](@ref) configuration. + +# Keywords +- `filter::String`: filter as defined in [MLFlow documentation](https://mlflow.org/docs/latest/rest-api.html#search-experiments) +- `filter_attributes::AbstractDict{K,V}`: if provided, `filter` is automatically generated based on `filter_attributes` using [`generatefilterfromattributes`](@ref). One can only provide either `filter` or `filter_attributes`, but not both. +- `run_view_type::String`: one of `ACTIVE_ONLY`, `DELETED_ONLY`, or `ALL`. +- `max_results::Int64`: 50,000 by default. +- `order_by::AbstractVector{<:String}`: as defined in [MLFlow documentation](https://mlflow.org/docs/latest/rest-api.html#search-experiments) +- `page_token::String`: paging functionality, handled automatically. Not meant to be passed by the user. + +# Returns +- vector of [`MLFlowExperiment`](@ref) experiments that were found in the MLFlow instance -TODO: not yet entirely implemented """ -function listexperiments(mlf::MLFlow) - endpoint = "experiments/list" - mlfget(mlf, endpoint) +function searchexperiments(mlf::MLFlow; + filter::String="", + filter_attributes::AbstractDict{K,V}=Dict{}(), + run_view_type::String="ACTIVE_ONLY", + max_results::Int64=50000, + order_by::AbstractVector{<:String}=["attribute.last_update_time"], + page_token::String="" +) where {K,V} + endpoint = "experiments/search" + run_view_type ∈ ["ACTIVE_ONLY", "DELETED_ONLY", "ALL"] || error("Unsupported run_view_type = $run_view_type") + + if length(filter_attributes) > 0 && length(filter) > 0 + error("Cannot specify both filter and filter_attributes") + end + + if length(filter_attributes) > 0 + filter = generatefilterfromattributes(filter_attributes) + end + + kwargs = (; filter, run_view_type, max_results, order_by) + if !isempty(page_token) + kwargs = (; kwargs..., page_token=page_token) + end + + result = mlfpost(mlf, endpoint; kwargs...) 
+ haskey(result, "experiments") || return MLFlowExperiment[] + + experiments = map(x -> MLFlowExperiment(x), result["experiments"]) + + if haskey(result, "next_page_token") && !isempty(result["next_page_token"]) + kwargs = (; filter, run_view_type, max_results, order_by, page_token=result["next_page_token"]) + next_experiments = searchexperiments(mlf; kwargs...) + return vcat(experiments, next_experiments) + end + + experiments end diff --git a/src/loggers.jl b/src/loggers.jl new file mode 100644 index 0000000..c7b6096 --- /dev/null +++ b/src/loggers.jl @@ -0,0 +1,196 @@ +""" + logparam(mlf::MLFlow, run, key, value) + logparam(mlf::MLFlow, run, kv) + +Associates a key/value pair of parameters to the particular run. + +# Arguments +- `mlf`: [`MLFlow`](@ref) configuration. +- `run`: one of [`MLFlowRun`](@ref), [`MLFlowRunInfo`](@ref), or `String`. +- `key`: parameter key (name). Automatically converted to string before sending to MLFlow because this is the only type that MLFlow supports. +- `value`: parameter value. Automatically converted to string before sending to MLFlow because this is the only type that MLFlow supports. + +One could also specify `kv::Dict` instead of separate `key` and `value` arguments. +""" +function logparam(mlf::MLFlow, run_id::String, key, value) + endpoint = "runs/log-parameter" + mlfpost(mlf, endpoint; run_id=run_id, key=string(key), value=string(value)) +end +logparam(mlf::MLFlow, run_info::MLFlowRunInfo, key, value) = + logparam(mlf, run_info.run_id, key, value) +logparam(mlf::MLFlow, run::MLFlowRun, key, value) = + logparam(mlf, run.info, key, value) +function logparam(mlf::MLFlow, run::Union{String,MLFlowRun,MLFlowRunInfo}, kv) + for (k, v) in kv + logparam(mlf, run, k, v) + end +end + +""" + logmetric(mlf::MLFlow, run, key, value::T; timestamp, step) where T<:Real + logmetric(mlf::MLFlow, run, key, values::AbstractArray{T}; timestamp, step) where T<:Real + +Logs a metric value (or values) against a particular run. 
+ +# Arguments +- `mlf`: [`MLFlow`](@ref) configuration. +- `run`: one of [`MLFlowRun`](@ref), [`MLFlowRunInfo`](@ref), or `String` +- `key`: metric name. +- `value`: metric value, must be numeric. + +# Keywords +- `timestamp`: if provided, must be a UNIX timestamp in milliseconds. By default, set to current time. +- `step`: step at which the metric value has been taken. +""" +function logmetric(mlf::MLFlow, run_id::String, key, value::T; timestamp=missing, step=missing) where {T<:Real} + endpoint = "runs/log-metric" + if ismissing(timestamp) + timestamp = Int(trunc(datetime2unix(now(UTC)) * 1000)) + end + mlfpost(mlf, endpoint; run_id=run_id, key=key, value=value, timestamp=timestamp, step=step) +end +logmetric(mlf::MLFlow, run_info::MLFlowRunInfo, key, value::T; timestamp=missing, step=missing) where {T<:Real} = + logmetric(mlf::MLFlow, run_info.run_id, key, value; timestamp=timestamp, step=step) +logmetric(mlf::MLFlow, run::MLFlowRun, key, value::T; timestamp=missing, step=missing) where {T<:Real} = + logmetric(mlf, run.info, key, value; timestamp=timestamp, step=step) + +function logmetric(mlf::MLFlow, run::Union{String,MLFlowRun,MLFlowRunInfo}, key, values::AbstractArray{T}; timestamp=missing, step=missing) where {T<:Real} + for v in values + logmetric(mlf, run, key, v; timestamp=timestamp, step=step) + end +end + + +""" + logartifact(mlf::MLFlow, run, basefilename, data) + +Stores an artifact (file) in the run's artifact location. + +!!! note + Assumes that artifact_uri is mapped to a local directory. + At the moment, this only works if both MLFlow and the client are running on the same host or they map a directory that leads to the same location over NFS, for example. + +# Arguments +- `mlf::MLFlow`: [`MLFlow`](@ref) configuration. Currently not used, but when this method is extended to support `S3`, information from `mlf` will be needed. +- `run`: one of [`MLFlowRun`](@ref), [`MLFlowRunInfo`](@ref) or `String`. 
+- `basefilename`: name of the file to be written. +- `data`: artifact content, an object that can be written directly to a file handle. + +# Throws +- an `ErrorException` if an exception occurs during writing artifact. + +# Returns +path of the artifact that was created. +""" +function logartifact(mlf::MLFlow, run_id::AbstractString, basefilename::AbstractString, data) + mlflowrun = getrun(mlf, run_id) + artifact_uri = mlflowrun.info.artifact_uri + mkpath(artifact_uri) + filepath = joinpath(artifact_uri, basefilename) + try + f = open(filepath, "w") + write(f, data) + close(f) + catch e + error("Unable to create artifact $(filepath): $e") + end + filepath +end +logartifact(mlf::MLFlow, run::MLFlowRun, basefilename::AbstractString, data) = + logartifact(mlf, run.info, basefilename, data) +logartifact(mlf::MLFlow, run_info::MLFlowRunInfo, basefilename::AbstractString, data) = + logartifact(mlf, run_info.run_id, basefilename, data) + +""" + logartifact(mlf::MLFlow, run, filepath) + +Stores an artifact (file) in the run's artifact location. +The name of the artifact is calculated using `basename(filepath)`. + +Dispatches on `logartifact(mlf::MLFlow, run, basefilename, data)` where `data` is the contents of `filepath`. + +# Throws +- an `ErrorException` if `filepath` does not exist. +- an exception if such occurs while trying to read the contents of `filepath`. 
+ +""" +function logartifact(mlf::MLFlow, run_id::AbstractString, filepath::Union{AbstractPath,AbstractString}) + isfile(filepath) || error("File $filepath does not exist.") + try + f = open(filepath, "r") + data = read(f) + close(f) + return logartifact(mlf, run_id, basename(filepath), data) + catch e + throw(e) + finally + if @isdefined f + close(f) + end + end +end +logartifact(mlf::MLFlow, run::MLFlowRun, filepath::Union{AbstractPath,AbstractString}) = + logartifact(mlf, run.info, filepath) +logartifact(mlf::MLFlow, run_info::MLFlowRunInfo, filepath::Union{AbstractPath,AbstractString}) = + logartifact(mlf, run_info.run_id, filepath) + +""" + listartifacts(mlf::MLFlow, run) + +Lists the artifacts associated with an experiment run. +According to [MLFlow documentation](https://mlflow.org/docs/latest/rest-api.html#list-artifacts), this API endpoint should return paged results, similar to [`searchruns`](@ref). +However, after some experimentation, this doesn't seem to be the case. Therefore, the paging functionality is not implemented here. + +# Arguments +- `mlf::MLFlow`: [`MLFlow`](@ref) configuration. Currently not used, but when this method is extended to support `S3`, information from `mlf` will be needed. +- `run`: one of [`MLFlowRun`](@ref), [`MLFlowRunInfo`](@ref) or `String`. + +# Keywords +- `path::String`: path of a directory within the artifact location. If set, returns the contents of the directory. By default, this is the root directory of the artifacts. +- `maxdepth::Int64`: depth of listing. Default is 1. This will only return the files/directories in the current `path`. To return all artifacts files and directories, use `maxdepth=-1`. + +# Returns +A vector of `Union{MLFlowArtifactFileInfo,MLFlowArtifactDirInfo}`. 
+""" +function listartifacts(mlf::MLFlow, run_id::String; path::String="", maxdepth::Int64=1) + endpoint = "artifacts/list" + kwargs = ( + run_id=run_id, + ) + kwargs = (; kwargs..., path=path) + httpresult = mlfget(mlf, endpoint; kwargs...) + "files" ∈ keys(httpresult) || return Vector{Union{MLFlowArtifactFileInfo,MLFlowArtifactDirInfo}}() + "root_uri" ∈ keys(httpresult) || error("Malformed response from MLFlow REST API.") + root_uri = httpresult["root_uri"] + result = Vector{Union{MLFlowArtifactFileInfo,MLFlowArtifactDirInfo}}() + maxdepth == 0 && return result + + for resultentry ∈ httpresult["files"] + if resultentry["is_dir"] == false + filepath = joinpath(root_uri, resultentry["path"]) + file_size = resultentry["file_size"] + if typeof(file_size) <: Int + filesize = file_size + else + filesize = parse(Int, file_size) + end + push!(result, MLFlowArtifactFileInfo(filepath, filesize)) + elseif resultentry["is_dir"] == true + dirpath = joinpath(root_uri, resultentry["path"]) + push!(result, MLFlowArtifactDirInfo(dirpath)) + if maxdepth != 0 + nextdepthresult = listartifacts(mlf, run_id, path=resultentry["path"], maxdepth=maxdepth - 1) + result = vcat(result, nextdepthresult) + end + else + isdirval = resultentry["is_dir"] + @warn "Malformed response from MLFlow REST API is_dir=$isdirval - skipping" + continue + end + end + result +end +listartifacts(mlf::MLFlow, run::MLFlowRun; kwargs...) = + listartifacts(mlf, run.info.run_id; kwargs...) +listartifacts(mlf::MLFlow, run_info::MLFlowRunInfo; kwargs...) = + listartifacts(mlf, run_info.run_id; kwargs...) diff --git a/src/runs.jl b/src/runs.jl index abf662d..9115680 100644 --- a/src/runs.jl +++ b/src/runs.jl @@ -132,13 +132,13 @@ Searches for runs in an experiment. 
""" function searchruns(mlf::MLFlow, experiment_ids::AbstractVector{<:Integer}; - filter::String="", - filter_params::AbstractDict{K,V}=Dict{}(), - run_view_type::String="ACTIVE_ONLY", - max_results::Int64=50000, - order_by::AbstractVector{<:String}=["attribute.end_time"], - page_token::String="" - ) where {K,V} + filter::String="", + filter_params::AbstractDict{K,V}=Dict{}(), + run_view_type::String="ACTIVE_ONLY", + max_results::Int64=50000, + order_by::AbstractVector{<:String}=["attribute.end_time"], + page_token::String="" +) where {K,V} endpoint = "runs/search" run_view_type ∈ ["ACTIVE_ONLY", "DELETED_ONLY", "ALL"] || error("Unsupported run_view_type = $run_view_type") @@ -175,8 +175,8 @@ function searchruns(mlf::MLFlow, experiment_ids::AbstractVector{<:Integer}; order_by=order_by, page_token=result["next_page_token"] ) - nextruns = searchruns(mlf, experiment_ids; kwargs...) - return vcat(runs, nextruns) + next_runs = searchruns(mlf, experiment_ids; kwargs...) + return vcat(runs, next_runs) end runs @@ -186,202 +186,4 @@ searchruns(mlf::MLFlow, experiment_id::Integer; kwargs...) = searchruns(mlf::MLFlow, exp::MLFlowExperiment; kwargs...) = searchruns(mlf, exp.experiment_id; kwargs...) searchruns(mlf::MLFlow, exps::AbstractVector{MLFlowExperiment}; kwargs...) = - searchruns(mlf, [getfield.(exps, :experiment_id)]; kwargs...) - - -""" - logparam(mlf::MLFlow, run, key, value) - logparam(mlf::MLFlow, run, kv) - -Associates a key/value pair of parameters to the particular run. - -# Arguments -- `mlf`: [`MLFlow`](@ref) configuration. -- `run`: one of [`MLFlowRun`](@ref), [`MLFlowRunInfo`](@ref), or `String`. -- `key`: parameter key (name). Automatically converted to string before sending to MLFlow because this is the only type that MLFlow supports. -- `value`: parameter value. Automatically converted to string before sending to MLFlow because this is the only type that MLFlow supports. 
- -One could also specify `kv::Dict` instead of separate `key` and `value` arguments. -""" -function logparam(mlf::MLFlow, run_id::String, key, value) - endpoint = "runs/log-parameter" - mlfpost(mlf, endpoint; run_id=run_id, key=string(key), value=string(value)) -end -logparam(mlf::MLFlow, run_info::MLFlowRunInfo, key, value) = - logparam(mlf, run_info.run_id, key, value) -logparam(mlf::MLFlow, run::MLFlowRun, key, value) = - logparam(mlf, run.info, key, value) -function logparam(mlf::MLFlow, run::Union{String,MLFlowRun,MLFlowRunInfo}, kv) - for (k, v) in kv - logparam(mlf, run, k, v) - end -end - -""" - logmetric(mlf::MLFlow, run, key, value::T; timestamp, step) where T<:Real - logmetric(mlf::MLFlow, run, key, values::AbstractArray{T}; timestamp, step) where T<:Real - -Logs a metric value (or values) against a particular run. - -# Arguments -- `mlf`: [`MLFlow`](@ref) configuration. -- `run`: one of [`MLFlowRun`](@ref), [`MLFlowRunInfo`](@ref), or `String` -- `key`: metric name. -- `value`: metric value, must be numeric. - -# Keywords -- `timestamp`: if provided, must be a UNIX timestamp in milliseconds. By default, set to current time. -- `step`: step at which the metric value has been taken. 
-""" -function logmetric(mlf::MLFlow, run_id::String, key, value::T; timestamp=missing, step=missing) where T<:Real - endpoint = "runs/log-metric" - if ismissing(timestamp) - timestamp = Int(trunc(datetime2unix(now(UTC)) * 1000)) - end - mlfpost(mlf, endpoint; run_id=run_id, key=key, value=value, timestamp=timestamp, step=step) -end -logmetric(mlf::MLFlow, run_info::MLFlowRunInfo, key, value::T; timestamp=missing, step=missing) where T<:Real = - logmetric(mlf::MLFlow, run_info.run_id, key, value; timestamp=timestamp, step=step) -logmetric(mlf::MLFlow, run::MLFlowRun, key, value::T; timestamp=missing, step=missing) where T<:Real = - logmetric(mlf, run.info, key, value; timestamp=timestamp, step=step) - -function logmetric(mlf::MLFlow, run::Union{String,MLFlowRun,MLFlowRunInfo}, key, values::AbstractArray{T}; timestamp=missing, step=missing) where T<:Real - for v in values - logmetric(mlf, run, key, v; timestamp=timestamp, step=step) - end -end - - -""" - logartifact(mlf::MLFlow, run, basefilename, data) - -Stores an artifact (file) in the run's artifact location. - -!!! note - Assumes that artifact_uri is mapped to a local directory. - At the moment, this only works if both MLFlow and the client are running on the same host or they map a directory that leads to the same location over NFS, for example. - -# Arguments -- `mlf::MLFlow`: [`MLFlow`](@ref) onfiguration. Currently not used, but when this method is extended to support `S3`, information from `mlf` will be needed. -- `run`: one of [`MLFlowRun`](@ref), [`MLFlowRunInfo`](@ref) or `String`. -- `basefilename`: name of the file to be written. -- `data`: artifact content, an object that can be written directly to a file handle. - -# Throws -- an `ErrorException` if an exception occurs during writing artifact. - -# Returns -path of the artifact that was created. 
-""" -function logartifact(mlf::MLFlow, run_id::AbstractString, basefilename::AbstractString, data) - mlflowrun = getrun(mlf, run_id) - artifact_uri = mlflowrun.info.artifact_uri - mkpath(artifact_uri) - filepath = joinpath(artifact_uri, basefilename) - try - f = open(filepath, "w") - write(f, data) - close(f) - catch e - error("Unable to create artifact $(filepath): $e") - end - filepath -end -logartifact(mlf::MLFlow, run::MLFlowRun, basefilename::AbstractString, data) = - logartifact(mlf, run.info, basefilename, data) -logartifact(mlf::MLFlow, run_info::MLFlowRunInfo, basefilename::AbstractString, data) = - logartifact(mlf, run_info.run_id, basefilename, data) - -""" - logartifact(mlf::MLFlow, run, filepath) - -Stores an artifact (file) in the run's artifact location. -The name of the artifact is calculated using `basename(filepath)`. - -Dispatches on `logartifact(mlf::MLFlow, run, basefilename, data)` where `data` is the contents of `filepath`. - -# Throws -- an `ErrorException` if `filepath` does not exist. -- an exception if such occurs while trying to read the contents of `filepath`. - -""" -function logartifact(mlf::MLFlow, run_id::AbstractString, filepath::Union{AbstractPath,AbstractString}) - isfile(filepath) || error("File $filepath does not exist.") - try - f = open(filepath, "r") - data = read(f) - close(f) - return logartifact(mlf, run_id, basename(filepath), data) - catch e - throw(e) - finally - if @isdefined f - close(f) - end - end -end -logartifact(mlf::MLFlow, run::MLFlowRun, filepath::Union{AbstractPath,AbstractString}) = - logartifact(mlf, run.info, filepath) -logartifact(mlf::MLFlow, run_info::MLFlowRunInfo, filepath::Union{AbstractPath,AbstractString}) = - logartifact(mlf, run_info.run_id, filepath) - -""" - listartifacts(mlf::MLFlow, run) - -Lists the artifacts associated with an experiment run. 
-According to [MLFlow documentation](https://mlflow.org/docs/latest/rest-api.html#list-artifacts), this API endpoint should return paged results, similar to [`searchruns`](@ref). -However, after some experimentation, this doesn't seem to be the case. Therefore, the paging functionality is not implemented here. - -# Arguments -- `mlf::MLFlow`: [`MLFlow`](@ref) onfiguration. Currently not used, but when this method is extended to support `S3`, information from `mlf` will be needed. -- `run`: one of [`MLFlowRun`](@ref), [`MLFlowRunInfo`](@ref) or `String`. - -# Keywords -- `path::String`: path of a directory within the artifact location. If set, returns the contents of the directory. By default, this is the root directory of the artifacts. -- `maxdepth::Int64`: depth of listing. Default is 1. This will only return the files/directories in the current `path`. To return all artifacts files and directories, use `maxdepth=-1`. - -# Returns -A vector of `Union{MLFlowArtifactFileInfo,MLFlowArtifactDirInfo}`. -""" -function listartifacts(mlf::MLFlow, run_id::String; path::String="", maxdepth::Int64=1) - endpoint = "artifacts/list" - kwargs = ( - run_id=run_id, - ) - kwargs = (; kwargs..., path=path) - httpresult = mlfget(mlf, endpoint; kwargs...) 
- "files" ∈ keys(httpresult) || return Vector{Union{MLFlowArtifactFileInfo,MLFlowArtifactDirInfo}}() - "root_uri" ∈ keys(httpresult) || error("Malformed response from MLFlow REST API.") - root_uri = httpresult["root_uri"] - result = Vector{Union{MLFlowArtifactFileInfo,MLFlowArtifactDirInfo}}() - maxdepth == 0 && return result - - for resultentry ∈ httpresult["files"] - if resultentry["is_dir"] == false - filepath = joinpath(root_uri, resultentry["path"]) - file_size = resultentry["file_size"] - if typeof(file_size) <: Int - filesize = file_size - else - filesize = parse(Int, file_size) - end - push!(result, MLFlowArtifactFileInfo(filepath, filesize)) - elseif resultentry["is_dir"] == true - dirpath = joinpath(root_uri, resultentry["path"]) - push!(result, MLFlowArtifactDirInfo(dirpath)) - if maxdepth != 0 - nextdepthresult = listartifacts(mlf, run_id, path=resultentry["path"], maxdepth=maxdepth-1) - result = vcat(result, nextdepthresult) - end - else - isdirval = resultentry["is_dir"] - @warn "Malformed response from MLFlow REST API is_dir=$isdirval - skipping" - continue - end - end - result -end -listartifacts(mlf::MLFlow, run::MLFlowRun; kwargs...) = - listartifacts(mlf, run.info.run_id; kwargs...) -listartifacts(mlf::MLFlow, run_info::MLFlowRunInfo; kwargs...) = - listartifacts(mlf, run_info.run_id; kwargs...) + searchruns(mlf, getfield.(exps, :experiment_id); kwargs...) diff --git a/src/types/artifacts.jl b/src/types/artifacts.jl new file mode 100644 index 0000000..1b7b984 --- /dev/null +++ b/src/types/artifacts.jl @@ -0,0 +1,31 @@ +""" + MLFlowArtifactFileInfo + +Metadata of a single artifact file -- result of [`listartifacts`](@ref). + +# Fields +- `filepath::String`: File path, including the root artifact directory of a run. +- `filesize::Int64`: Size in bytes. 
+""" +struct MLFlowArtifactFileInfo + filepath::String + filesize::Int64 +end +Base.show(io::IO, t::MLFlowArtifactFileInfo) = show(io, ShowCase(t, new_lines=true)) +get_path(mlfafi::MLFlowArtifactFileInfo) = mlfafi.filepath +get_size(mlfafi::MLFlowArtifactFileInfo) = mlfafi.filesize + +""" + MLFlowArtifactDirInfo + +Metadata of a single artifact directory -- result of [`listartifacts`](@ref). + +# Fields +- `dirpath::String`: Directory path, including the root artifact directory of a run. +""" +struct MLFlowArtifactDirInfo + dirpath::String +end +Base.show(io::IO, t::MLFlowArtifactDirInfo) = show(io, ShowCase(t, new_lines=true)) +get_path(mlfadi::MLFlowArtifactDirInfo) = mlfadi.dirpath +get_size(mlfadi::MLFlowArtifactDirInfo) = 0 diff --git a/src/types/core.jl b/src/types/core.jl new file mode 100644 index 0000000..9b1b577 --- /dev/null +++ b/src/types/core.jl @@ -0,0 +1,70 @@ +""" + MLFlow + +Base type which defines location and version for MLFlow API service. + +# Fields +- `baseuri::String`: base MLFlow tracking URI, e.g. `http://localhost:5000` +- `apiversion`: used API version, e.g. `2.0` +- `headers`: HTTP headers to be provided with the REST API requests (useful for authentication tokens) + +# Constructors + +- `MLFlow(baseuri; apiversion=2.0,headers=Dict())` +- `MLFlow()` - defaults to `MLFlow("http://localhost:5000")` + +# Examples + +```@example +mlf = MLFlow() +``` + +```@example +remote_url="https://.cloud.databricks.com"; # address of your remote server +mlf = MLFlow(remote_url, headers=Dict("Authorization" => "Bearer ")) +``` + +""" +struct MLFlow + baseuri::String + apiversion + headers::Dict +end +MLFlow(baseuri; apiversion=2.0,headers=Dict()) = MLFlow(baseuri, apiversion,headers) +MLFlow() = MLFlow("http://localhost:5000", 2.0, Dict()) +Base.show(io::IO, t::MLFlow) = show(io, ShowCase(t, [:baseuri,:apiversion], new_lines=true)) + +""" + MLFlowExperiment + +Represents an MLFlow experiment. + +# Fields +- `name::String`: experiment name. 
+- `lifecycle_stage::String`: life cycle stage, one of ["active", "deleted"] +- `experiment_id::Integer`: experiment identifier. +- `tags::Any`: list of tags. +- `artifact_location::String`: where are experiment artifacts stored. + +# Constructors + +- `MLFlowExperiment(name, lifecycle_stage, experiment_id, tags, artifact_location)` +- `MLFlowExperiment(exp::Dict{String,Any})` + +""" +struct MLFlowExperiment + name::String + lifecycle_stage::String + experiment_id::Integer + tags::Any + artifact_location::String +end +function MLFlowExperiment(exp::Dict{String,Any}) + name = get(exp, "name", missing) + lifecycle_stage = get(exp, "lifecycle_stage", missing) + experiment_id = parse(Int, get(exp, "experiment_id", missing)) + tags = get(exp, "tags", missing) + artifact_location = get(exp, "artifact_location", missing) + MLFlowExperiment(name, lifecycle_stage, experiment_id, tags, artifact_location) +end +Base.show(io::IO, t::MLFlowExperiment) = show(io, ShowCase(t, new_lines=true)) diff --git a/src/types.jl b/src/types/runs.jl similarity index 64% rename from src/types.jl rename to src/types/runs.jl index a1d64bb..3b60739 100644 --- a/src/types.jl +++ b/src/types/runs.jl @@ -1,74 +1,3 @@ -""" - MLFlow - -Base type which defines location and version for MLFlow API service. - -# Fields -- `baseuri::String`: base MLFlow tracking URI, e.g. `http://localhost:5000` -- `apiversion`: used API version, e.g. 
`2.0` -- `headers`: HTTP headers to be provided with the REST API requests (useful for authetication tokens) - -# Constructors - -- `MLFlow(baseuri; apiversion=2.0,headers=Dict())` -- `MLFlow()` - defaults to `MLFlow("http://localhost:5000")` - -# Examples - -```@example -mlf = MLFlow() -``` - -```@example -remote_url="https://.cloud.databricks.com"; # address of your remote server -mlf = MLFlow(remote_url, headers=Dict("Authorization" => "Bearer ")) -``` - -""" -struct MLFlow - baseuri::String - apiversion - headers::Dict -end -MLFlow(baseuri; apiversion=2.0,headers=Dict()) = MLFlow(baseuri, apiversion,headers) -MLFlow() = MLFlow("http://localhost:5000", 2.0, Dict()) -Base.show(io::IO, t::MLFlow) = show(io, ShowCase(t, [:baseuri,:apiversion], new_lines=true)) - -""" - MLFlowExperiment - -Represents an MLFlow experiment. - -# Fields -- `name::String`: experiment name. -- `lifecycle_stage::String`: life cycle stage, one of ["active", "deleted"] -- `experiment_id::Integer`: experiment identifier. -- `tags::Any`: list of tags. -- `artifact_location::String`: where are experiment artifacts stored. 
- -# Constructors - -- `MLFlowExperiment(name, lifecycle_stage, experiment_id, tags, artifact_location)` -- `MLFlowExperiment(exp::Dict{String,Any})` - -""" -struct MLFlowExperiment - name::String - lifecycle_stage::String - experiment_id::Integer - tags::Any - artifact_location::String -end -function MLFlowExperiment(exp::Dict{String,Any}) - name = get(exp, "name", missing) - lifecycle_stage = get(exp, "lifecycle_stage", missing) - experiment_id = parse(Int, get(exp, "experiment_id", missing)) - tags = get(exp, "tags", missing) - artifact_location = get(exp, "artifact_location", missing) - MLFlowExperiment(name, lifecycle_stage, experiment_id, tags, artifact_location) -end -Base.show(io::IO, t::MLFlowExperiment) = show(io, ShowCase(t, new_lines=true)) - """ MLFlowRunStatus @@ -261,35 +190,3 @@ get_info(run::MLFlowRun) = run.info get_data(run::MLFlowRun) = run.data get_run_id(run::MLFlowRun) = get_run_id(run.info) get_params(run::MLFlowRun) = get_params(run.data) - -""" - MLFlowArtifactFileInfo - -Metadata of a single artifact file -- result of [`listartifacts`](@ref). - -# Fields -- `filepath::String`: File path, including the root artifact directory of a run. -- `filesize::Int64`: Size in bytes. -""" -struct MLFlowArtifactFileInfo - filepath::String - filesize::Int64 -end -Base.show(io::IO, t::MLFlowArtifactFileInfo) = show(io, ShowCase(t, new_lines=true)) -get_path(mlfafi::MLFlowArtifactFileInfo) = mlfafi.filepath -get_size(mlfafi::MLFlowArtifactFileInfo) = mlfafi.filesize - -""" - MLFlowArtifactDirInfo - -Metadata of a single artifact directory -- result of [`listartifacts`](@ref). - -# Fields -- `dirpath::String`: Directory path, including the root artifact directory of a run. 
-""" -struct MLFlowArtifactDirInfo - dirpath::String -end -Base.show(io::IO, t::MLFlowArtifactDirInfo) = show(io, ShowCase(t, new_lines=true)) -get_path(mlfadi::MLFlowArtifactDirInfo) = mlfadi.dirpath -get_size(mlfadi::MLFlowArtifactDirInfo) = 0 diff --git a/src/utils.jl b/src/utils.jl index 4cb215e..8874c04 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -24,7 +24,7 @@ Retrieves HTTP headers based on `mlf` and merges with user-provided `custom_head headers(mlf,Dict("Content-Type"=>"application/json")) ``` """ -headers(mlf::MLFlow,custom_headers::AbstractDict)=merge(mlf.headers, custom_headers) +headers(mlf::MLFlow, custom_headers::AbstractDict) = merge(mlf.headers, custom_headers) """ mlfget(mlf, endpoint; kwargs...) @@ -33,7 +33,7 @@ Performs a HTTP GET to a specifid endpoint. kwargs are turned into GET params. """ function mlfget(mlf, endpoint; kwargs...) apiuri = uri(mlf, endpoint, kwargs) - apiheaders = headers(mlf,Dict("Content-Type"=>"application/json")) + apiheaders = headers(mlf, Dict("Content-Type" => "application/json")) try response = HTTP.get(apiuri, apiheaders) return JSON.parse(String(response.body)) @@ -49,7 +49,7 @@ Performs a HTTP POST to the specified endpoint. kwargs are converted to JSON and """ function mlfpost(mlf, endpoint; kwargs...) apiuri = uri(mlf, endpoint) - apiheaders = headers(mlf,Dict("Content-Type"=>"application/json")) + apiheaders = headers(mlf, Dict("Content-Type" => "application/json")) body = JSON.json(kwargs) try response = HTTP.post(apiuri, apiheaders, body) @@ -60,12 +60,13 @@ function mlfpost(mlf, endpoint; kwargs...) end """ - generatefilterfromparams(filter_params::AbstractDict{K,V}) where {K,V} + generatefilterfromentity_type(filter_params::AbstractDict{K,V}, entity_type::String) where {K,V} -Generates a `filter` string from `filter_params` dictionary. +Generates a `filter` string from `filter_params` dictionary and `entity_type`. # Arguments - `filter_params`: dictionary to use for filter generation. 
+- `entity_type`: entity type to use for filter generation. # Returns A string that can be passed as `filter` to [`searchruns`](@ref). @@ -73,12 +74,14 @@ A string that can be passed as `filter` to [`searchruns`](@ref). # Examples ```@example -generatefilterfromparams(Dict("paramkey1" => "paramvalue1", "paramkey2" => "paramvalue2")) +generatefilterfromentity_type(Dict("paramkey1" => "paramvalue1", "paramkey2" => "paramvalue2"), "param") ``` """ -function generatefilterfromparams(filter_params::AbstractDict{K,V}) where {K,V} +function generatefilterfromentity_type(filter_params::AbstractDict{K,V}, entity_type::String) where {K,V} length(filter_params) > 0 || return "" # NOTE: may have issues with escaping. - filters = ["param.\"$(k)\" = \"$(v)\"" for(k, v) ∈ filter_params] + filters = ["$(entity_type).\"$(k)\" = \"$(v)\"" for (k, v) ∈ filter_params] join(filters, " and ") end +generatefilterfromparams(filter_params::AbstractDict{K,V}) where {K,V} = generatefilterfromentity_type(filter_params, "param") +generatefilterfromattributes(filter_attributes::AbstractDict{K,V}) where {K,V} = generatefilterfromentity_type(filter_attributes, "attribute") diff --git a/test/test_base.jl b/test/base.jl similarity index 100% rename from test/test_base.jl rename to test/base.jl diff --git a/test/runtests.jl b/test/runtests.jl index 60826d1..e478e47 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1 +1,5 @@ -include("test_functional.jl") +include("base.jl") + +include("test_experiments.jl") +include("test_runs.jl") +include("test_loggers.jl") diff --git a/test/test_experiments.jl b/test/test_experiments.jl new file mode 100644 index 0000000..03ba028 --- /dev/null +++ b/test/test_experiments.jl @@ -0,0 +1,87 @@ +@testset "createexperiment" begin + @ensuremlf + exp = createexperiment(mlf) + + @test isa(exp, MLFlowExperiment) + deleteexperiment(mlf, exp) +end + +@testset verbose=true "getexperiment" begin + @ensuremlf + exp = createexperiment(mlf) + experiment = 
getexperiment(mlf, exp.experiment_id) + + @testset "getexperiment_by_experiment_id" begin + @test isa(experiment, MLFlowExperiment) + @test experiment.experiment_id == exp.experiment_id + end + + @testset "getexperiment_by_experiment_name" begin + experiment_by_name = getexperiment(mlf, exp.name) + @test isa(experiment_by_name, MLFlowExperiment) + @test experiment_by_name.experiment_id == exp.experiment_id + end + + @testset "getexperiment_not_found" begin + @test isa(getexperiment(mlf, 123), Missing) + end + deleteexperiment(mlf, exp) +end + +@testset "getorcreateexperiment" begin + @ensuremlf + expname = "getorcreate" + artifact_location = "test$(expname)" + e = getorcreateexperiment(mlf, expname; artifact_location=artifact_location) + @test isa(e, MLFlowExperiment) + + ee = getorcreateexperiment(mlf, expname) + @test isa(ee, MLFlowExperiment) + @test e === ee + @test occursin(artifact_location, e.artifact_location) + deleteexperiment(mlf, ee) +end + +@testset "deleteexperiment" begin + @ensuremlf + exp = createexperiment(mlf) + + @test deleteexperiment(mlf, exp) + # deleting again to test if the experiment is already deleted + @test deleteexperiment(mlf, exp) +end + +@testset verbose=true "searchexperiments" begin + @ensuremlf + n_experiments = 3 + for i in 2:n_experiments + createexperiment(mlf) + end + createexperiment(mlf; name="test") + experiments = searchexperiments(mlf) + + @testset "searchexperiments_get_all" begin + @test length(experiments) == (n_experiments + 1) # Adding one for the default experiment + end + + @testset "searchexperiments_by_filter" begin + experiments_by_filter = searchexperiments(mlf; filter="name=\"test\"") + @test length(experiments_by_filter) == 1 + @test experiments_by_filter[1].name == "test" + end + + @testset "searchexperiments_by_filter_attributes" begin + experiments_by_filter = searchexperiments(mlf; filter_attributes=Dict("name" => "test")) + @test length(experiments_by_filter) == 1 + @test experiments_by_filter[1].name 
== "test" + end + + @testset "searchexperiments_filter_exception" begin + @test_throws ErrorException searchexperiments(mlf; filter="test", filter_attributes=Dict("test" => "test")) + end + + popfirst!(experiments) # removing the default experiment (it can't be deleted) + for e in experiments + deleteexperiment(mlf, e) + end +end diff --git a/test/test_functional.jl b/test/test_functional.jl index ed928bd..6eff632 100644 --- a/test/test_functional.jl +++ b/test/test_functional.jl @@ -1,5 +1,3 @@ -include("test_base.jl") - @testset "MLFlow" begin mlf = MLFlow() @test mlf.baseuri == "http://localhost:5000" @@ -9,8 +7,8 @@ include("test_base.jl") @test mlf.baseuri == "https://localhost:5001" @test mlf.apiversion == 3.0 @test mlf.headers == Dict() - let custom_headers=Dict("Authorization"=>"Bearer EMPTY") - mlf = MLFlow("https://localhost:5001", apiversion=3.0,headers=custom_headers) + let custom_headers = Dict("Authorization" => "Bearer EMPTY") + mlf = MLFlow("https://localhost:5001", apiversion=3.0, headers=custom_headers) @test mlf.baseuri == "https://localhost:5001" @test mlf.apiversion == 3.0 @test mlf.headers == custom_headers @@ -19,17 +17,17 @@ end # test that sensitive fields are not displayed by show() @testset "MLFLow/show" begin - let io=IOBuffer(), - secret_token="SECRET" + let io = IOBuffer(), + secret_token = "SECRET" - custom_headers=Dict("Authorization"=>"Bearer $secret_token") - mlf = MLFlow("https://localhost:5001", apiversion=3.0,headers=custom_headers) + custom_headers = Dict("Authorization" => "Bearer $secret_token") + mlf = MLFlow("https://localhost:5001", apiversion=3.0, headers=custom_headers) @test mlf.baseuri == "https://localhost:5001" @test mlf.apiversion == 3.0 @test mlf.headers == custom_headers - show(io,mlf) - show_output=String(take!(io)) - @test !(occursin(secret_token,show_output)) + show(io, mlf) + show_output = String(take!(io)) + @test !(occursin(secret_token, show_output)) end end @@ -43,115 +41,13 @@ end end let baseuri = 
"http://localhost:5001", auth_headers = Dict("Authorization" => "Bearer 123456"), custom_headers = Dict("Content-Type" => "application/json") + mlf = MLFlow(baseuri; headers=auth_headers) apiheaders = headers(mlf, custom_headers) @test apiheaders == Dict("Authorization" => "Bearer 123456", "Content-Type" => "application/json") end end -@testset "createexperiment" begin - @ensuremlf - exp = createexperiment(mlf) - @test isa(exp, MLFlowExperiment) - @test deleteexperiment(mlf, exp) - experiment = getexperiment(mlf, exp.experiment_id) - @test experiment.experiment_id == exp.experiment_id - @test experiment.lifecycle_stage == "deleted" -end - -@testset "createrun" begin - @ensuremlf - expname = "createrun-$(UUIDs.uuid4())" - e = getorcreateexperiment(mlf, expname) - runname = "run-$(UUIDs.uuid4())" - r = createrun(mlf, e.experiment_id; run_name=runname) - - @test isa(r, MLFlowRun) - @test r.info.run_name == runname - deleteexperiment(mlf, e) -end - -@testset "deleterun" begin - @ensuremlf - expname = "deleterun-$(UUIDs.uuid4())" - e = getorcreateexperiment(mlf, expname) - r = createrun(mlf, e.experiment_id) - - @test deleterun(mlf, r) - deleteexperiment(mlf, e) -end - -@testset "updaterun" begin - @ensuremlf - expname = "updaterun-$(UUIDs.uuid4())" - e = getorcreateexperiment(mlf, expname) - runname = "run-$(UUIDs.uuid4())" - r = createrun(mlf, e.experiment_id; run_name=runname) - - new_runname = "new_updaterun-$(UUIDs.uuid4())" - new_status = "FINISHED" - r_updated = updaterun(mlf, r, new_status; run_name=new_runname) - - @test isa(r_updated, MLFlowRun) - @test r_updated.info.run_name != r.info.run_name - @test r_updated.info.status.status != r.info.status - @test r_updated.info.run_name == new_runname - @test r_updated.info.status.status == new_status - deleteexperiment(mlf, e) -end - -@testset "getorcreateexperiment" begin - @ensuremlf - expname = "getorcreate" - artifact_location = "test$(expname)" - e = getorcreateexperiment(mlf, expname; 
artifact_location=artifact_location) - @test isa(e, MLFlowExperiment) - ee = getorcreateexperiment(mlf, expname) - @test isa(ee, MLFlowExperiment) - @test e === ee - @test occursin(artifact_location, e.artifact_location) - @test deleteexperiment(mlf, ee) - @test deleteexperiment(mlf, ee) -end - -@testset "generatefilterfromparama" begin - filter_params = Dict("k1" => "v1") - filter = generatefilterfromparams(filter_params) - @test filter == "param.\"k1\" = \"v1\"" - filter_params = Dict("k1" => "v1", "started" => Date("2020-01-01")) - filter = generatefilterfromparams(filter_params) - @test occursin("param.\"k1\" = \"v1\"", filter) - @test occursin("param.\"started\" = \"2020-01-01\"", filter) - @test occursin(" and ", filter) -end - -@testset "searchruns" begin - @ensuremlf - exp = createexperiment(mlf) - expid = exp.experiment_id - exprun = createrun(mlf, exp) - @test exprun.info.experiment_id == expid - @test exprun.info.lifecycle_stage == "active" - @test exprun.info.status == MLFlowRunStatus("RUNNING") - exprunid = exprun.info.run_id - - runparams = Dict( - "k1" => "v1", - "started" => Date("2020-01-01") - ) - logparam(mlf, exprun, runparams) - - findrun = searchruns(mlf, exp; filter_params=runparams) - @test length(findrun) == 1 - r = only(findrun) - @test get_run_id(get_info(r)) == exprun.info.run_id - @test get_run_id(r) == get_run_id(get_info(r)) - @test sort(collect(keys(get_params(get_data(r))))) == sort(string.(keys(runparams))) - @test sort(collect(values(get_params(get_data(r))))) == sort(string.(values(runparams))) - @test get_params(r) == get_params(get_data(r)) - @test deleteexperiment(mlf, exp) -end - @testset "artifacts" begin @ensuremlf exp = createexperiment(mlf) @@ -223,67 +119,3 @@ end deleterun(mlf, exprun) deleteexperiment(mlf, exp) end - -@testset "MLFlowClient.jl" begin - @ensuremlf - exp = createexperiment(mlf) - @test isa(exp, MLFlowExperiment) - - exptags = [:key => "val"] - expname = "expname-$(UUIDs.uuid4())" - - @test 
ismissing(getexperiment(mlf, "$(UUIDs.uuid4()) - $(UUIDs.uuid4())")) - - experiment = createexperiment(mlf; name=expname, tags=exptags) - experiment_id = experiment.experiment_id - experimentbyname = getexperiment(mlf, expname) - @test experimentbyname.name == experiment.name - - exprun = createrun(mlf, experiment_id) - @test exprun.info.experiment_id == experiment_id - @test exprun.info.lifecycle_stage == "active" - @test exprun.info.status == MLFlowRunStatus("RUNNING") - exprunid = exprun.info.run_id - - logparam(mlf, exprunid, "paramkey", "paramval") - logparam(mlf, exprunid, Dict("k" => "v", "k1" => "v1")) - logparam(mlf, exprun, Dict("test1" => "test2")) - - logmetric(mlf, exprun, "metrickeyrun", 1.0) - logmetric(mlf, exprun.info, "metrickeyrun", 2.0) - logmetric(mlf, exprun.info, "metrickeyrun", [2.5, 3.5]) - logmetric(mlf, exprunid, "metrickey", 1.0) - logmetric(mlf, exprunid, "metrickey2", [1.0, 1.5, 2.0]) - - retrieved_run = getrun(mlf, exprunid) - @test exprun.info == retrieved_run.info - - running_run = updaterun(mlf, exprunid, "RUNNING") - @test running_run.info.experiment_id == experiment_id - @test running_run.info.status == MLFlowRunStatus("RUNNING") - finished_run = updaterun(mlf, exprun, MLFlowRunStatus("FINISHED")) - finishedrun = getrun(mlf, finished_run.info.run_id) - - @test !ismissing(finishedrun.info.end_time) - - exprun2 = createrun(mlf, experiment_id) - exprun2id = exprun.info.run_id - logparam(mlf, exprun2, "param2", "key2") - logmetric(mlf, exprun2, "metric2", [1.0, 2.0]) - updaterun(mlf, exprun2, "FINISHED") - - runs = searchruns(mlf, experiment_id) - @test length(runs) == 2 - runs = searchruns(mlf, experiment_id; filter="param.param2 = \"key2\"") - @test length(runs) == 1 - @test_throws ErrorException searchruns(mlf, experiment_id; run_view_type="ERR") - runs = searchruns(mlf, experiment_id; filter="param.param2 = \"key3\"") - @test length(runs) == 0 - runs = searchruns(mlf, experiment_id; max_results=1) # test paging functionality - 
@test length(runs) == 2 - deleterun(mlf, exprunid) - deleterun(mlf, exprun2) - - deleteexperiment(mlf, exp) - deleteexperiment(mlf, experiment) -end diff --git a/test/test_loggers.jl b/test/test_loggers.jl new file mode 100644 index 0000000..865818c --- /dev/null +++ b/test/test_loggers.jl @@ -0,0 +1,144 @@ +@testset verbose = true "logparam" begin + @ensuremlf + expname = "logparam-$(UUIDs.uuid4())" + e = getorcreateexperiment(mlf, expname) + runname = "run-$(UUIDs.uuid4())" + r = createrun(mlf, e.experiment_id) + + @testset "logparam_by_run_id_and_key_value" begin + logparam(mlf, r.info.run_id, "run_id_key_value", "test") + retrieved_run = searchruns(mlf, e; filter_params=Dict("run_id_key_value" => "test")) + @test length(retrieved_run) == 1 + @test retrieved_run[1].info.run_id == r.info.run_id + end + + @testset "logparam_by_run_info_and_key_value" begin + logparam(mlf, r.info, "run_id_key_value", "test") + retrieved_run = searchruns(mlf, e; filter_params=Dict("run_id_key_value" => "test")) + @test length(retrieved_run) == 1 + @test retrieved_run[1].info.run_id == r.info.run_id + end + + @testset "logparam_by_run_and_key_value" begin + logparam(mlf, r, "run_id_key_value", "test") + retrieved_run = searchruns(mlf, e; filter_params=Dict("run_id_key_value" => "test")) + @test length(retrieved_run) == 1 + @test retrieved_run[1].info.run_id == r.info.run_id + end + + @testset "logparam_by_union_and_dict_key_value" begin + logparam(mlf, r, Dict("run_id_key_value" => "test")) + retrieved_run = searchruns(mlf, e; filter_params=Dict("run_id_key_value" => "test")) + @test length(retrieved_run) == 1 + @test retrieved_run[1].info.run_id == r.info.run_id + end + + deleteexperiment(mlf, e) +end + +@testset verbose = true "logmetric" begin + @ensuremlf + expname = "logmetric-$(UUIDs.uuid4())" + e = getorcreateexperiment(mlf, expname) + runname = "run-$(UUIDs.uuid4())" + r = createrun(mlf, e.experiment_id) + + @testset "logmetric_by_run_id_and_key_value" begin + logmetric(mlf, 
r.info.run_id, "run_id_key_value", 1) + retrieved_run = searchruns(mlf, e) + @test length(retrieved_run) == 1 + @test isa(retrieved_run[1].data.metrics["run_id_key_value"], MLFlowRunDataMetric) + @test retrieved_run[1].data.metrics["run_id_key_value"].value == 1 + end + + @testset "logmetric_by_run_info_and_key_value" begin + logmetric(mlf, r.info, "run_id_key_value", 1) + retrieved_run = searchruns(mlf, e) + @test length(retrieved_run) == 1 + @test isa(retrieved_run[1].data.metrics["run_id_key_value"], MLFlowRunDataMetric) + @test retrieved_run[1].data.metrics["run_id_key_value"].value == 1 + end + + @testset "logmetric_by_run_and_key_value" begin + logmetric(mlf, r, "run_id_key_value", 1) + retrieved_run = searchruns(mlf, e) + @test length(retrieved_run) == 1 + @test isa(retrieved_run[1].data.metrics["run_id_key_value"], MLFlowRunDataMetric) + @test retrieved_run[1].data.metrics["run_id_key_value"].value == 1 + end + + @testset "logmetric_by_union_and_key_arrayvalue" begin + logmetric(mlf, r, "run_id_key_value", [1, 2, 3]) + retrieved_run = searchruns(mlf, e) + @test length(retrieved_run) == 1 + @test isa(retrieved_run[1].data.metrics["run_id_key_value"], MLFlowRunDataMetric) + @test retrieved_run[1].data.metrics["run_id_key_value"].value == 3 + end + + deleteexperiment(mlf, e) +end + +@testset verbose = true "logartifact" begin + @ensuremlf + expname = "logartifact-$(UUIDs.uuid4())" + e = getorcreateexperiment(mlf, expname; artifact_location="/tmp/mlflow") + runname = "run-$(UUIDs.uuid4())" + r = createrun(mlf, e.experiment_id) + artifact_uri = r.info.artifact_uri + + tmpfile = "/tmp/mlflowclient-tempfile.txt" + open(tmpfile, "w") do f + write(f, "test") + end + + @testset "logartifact_by_run_and_filenameanddata" begin + artifact = logartifact(mlf, r, tmpfile, "testing") + @test isfile(artifact) + end + + @testset "logartifact_by_run_id_and_file" begin + artifact = logartifact(mlf, r.info.run_id, tmpfile) + @test isfile(artifact) + end + + @testset 
"logartifact_by_run_and_file" begin + artifact = logartifact(mlf, r, tmpfile) + @test isfile(artifact) + end + + @testset "logartifact_by_run_info_and_file" begin + artifact = logartifact(mlf, r.info, tmpfile) + @test isfile(artifact) + end + + @testset "logartifact_error" begin + @test_broken logartifact(mlf, r, "/etc/shadow") + end + + deleteexperiment(mlf, e) +end + +@testset verbose=true "listartifacts" begin + @ensuremlf + expname = "listartifacts-$(UUIDs.uuid4())" + e = getorcreateexperiment(mlf, expname) + runname = "run-$(UUIDs.uuid4())" + r = createrun(mlf, e.experiment_id) + + @testset "listartifacts_by_run_id" begin + artifacts = listartifacts(mlf, r.info.run_id) + @test length(artifacts) == 0 + end + + @testset "listartifacts_by_run" begin + artifacts = listartifacts(mlf, r) + @test length(artifacts) == 0 + end + + @testset "listartifacts_by_run_info" begin + artifacts = listartifacts(mlf, r.info) + @test length(artifacts) == 0 + end + + deleteexperiment(mlf, e) +end diff --git a/test/test_runs.jl b/test/test_runs.jl new file mode 100644 index 0000000..06adae1 --- /dev/null +++ b/test/test_runs.jl @@ -0,0 +1,172 @@ +@testset verbose = true "createrun" begin + @ensuremlf + expname = "createrun-$(UUIDs.uuid4())" + e = getorcreateexperiment(mlf, expname) + runname = "run-$(UUIDs.uuid4())" + + function runtests(run) + @test isa(run, MLFlowRun) + @test run.info.run_name == runname + end + + @testset "createrun_by_experiment_id" begin + r = createrun(mlf, e.experiment_id; run_name=runname) + runtests(r) + end + + @testset "createrun_by_experiment_type" begin + r = createrun(mlf, e; run_name=runname) + runtests(r) + end + + deleteexperiment(mlf, e) +end + +@testset "getrun" begin + @ensuremlf + expname = "getrun-$(UUIDs.uuid4())" + e = getorcreateexperiment(mlf, expname) + runname = "run-$(UUIDs.uuid4())" + r = createrun(mlf, e.experiment_id; run_name=runname) + + retrieved_r = getrun(mlf, r.info.run_id) + + @test isa(retrieved_r, MLFlowRun) + @test 
retrieved_r.info.run_id == r.info.run_id + deleteexperiment(mlf, e) +end + +@testset verbose = true "updaterun" begin + @ensuremlf + expname = "updaterun-$(UUIDs.uuid4())" + e = getorcreateexperiment(mlf, expname) + runname = "run-$(UUIDs.uuid4())" + r = createrun(mlf, e.experiment_id; run_name=runname) + + new_runname = "new_updaterun-$(UUIDs.uuid4())" + new_status = "FINISHED" + new_status_using_type = MLFlowRunStatus("FINISHED") + + function runtests(run_updated) + @test isa(run_updated, MLFlowRun) + @test run_updated.info.run_name != r.info.run_name + @test run_updated.info.status.status != r.info.status + @test run_updated.info.run_name == new_runname + @test run_updated.info.status.status == new_status + end + + @testset "updaterun_by_run_id" begin + r_updated = updaterun(mlf, r.info.run_id, new_status; run_name=new_runname) + runtests(r_updated) + end + @testset "updaterun_by_run_info" begin + r_updated = updaterun(mlf, r.info, new_status; run_name=new_runname) + runtests(r_updated) + end + @testset "updaterun_byrun" begin + r_updated = updaterun(mlf, r, new_status; run_name=new_runname) + runtests(r_updated) + end + + @testset "updaterun_by_run_info_and_defined_status" begin + r_updated = updaterun(mlf, r.info, new_status_using_type; run_name=new_runname) + runtests(r_updated) + end + @testset "updaterun_by_run_and_defined_status" begin + r_updated = updaterun(mlf, r, new_status_using_type; run_name=new_runname) + runtests(r_updated) + end + + deleteexperiment(mlf, e) +end + +@testset verbose = true "deleterun" begin + @ensuremlf + expname = "deleterun-$(UUIDs.uuid4())" + e = getorcreateexperiment(mlf, expname) + + function runtests(run) + @test deleterun(mlf, run) + end + + @testset "deleterun_by_run_info" begin + r = createrun(mlf, e.experiment_id) + runtests(r.info) + end + + @testset "deleterun_by_run" begin + r = createrun(mlf, e.experiment_id) + runtests(r) + end + + deleteexperiment(mlf, e) +end + +@testset verbose = true "searchruns" begin + 
@ensuremlf + getexpname() = "searchruns-$(UUIDs.uuid4())" + e1 = getorcreateexperiment(mlf, getexpname()) + e2 = getorcreateexperiment(mlf, getexpname()) + + run_array1 = MLFlowRun[] + run_array2 = MLFlowRun[] + run_status = ["FINISHED", "FINISHED", "FAILED"] + failed_runs = 0 + + function addruns!(run_array, experiment, run_status) + for status in run_status + run = createrun(mlf, experiment.experiment_id) + run = updaterun(mlf, run, status) + if status == "FAILED" + logparam(mlf, run, "test", "failed") + failed_runs += 1 + else + logparam(mlf, run, "test", "test") + end + push!(run_array, run) + end + end + + addruns!(run_array1, e1, run_status) + addruns!(run_array2, e2, run_status) + + @testset "searchruns_by_experiment_id" begin + runs = searchruns(mlf, e1.experiment_id) + @test runs |> length == run_array1 |> length + end + + @testset "searchruns_by_experiment" begin + runs = searchruns(mlf, e1) + @test runs |> length == run_array1 |> length + end + + @testset "searchruns_by_experiments_array" begin + runs = searchruns(mlf, [e1, e2]) + @test runs |> length == (run_array1 |> length) + (run_array2 |> length) + end + + @testset "searchruns_by_filter" begin + runs = searchruns(mlf, [e1, e2]; filter="param.test = \"failed\"") + @test failed_runs == runs |> length + end + + @testset "searchruns_by_filter_params" begin + runs = searchruns(mlf, [e1, e2]; filter_params=Dict("test" => "failed")) + @test failed_runs == runs |> length + end + + @testset "searchruns_filter_exception" begin + @test_throws ErrorException searchruns(mlf, [e1, e2]; filter="test", filter_params=Dict("test" => "test")) + end + + @testset "runs_get_methods" begin + runs = searchruns(mlf, [e1, e2]; filter_params=Dict("test" => "failed")) + @test get_info(runs[1]) == runs[1].info + @test get_data(runs[1]) == runs[1].data + @test get_run_id(runs[1]) == runs[1].info.run_id + @test get_params(runs[1]) == runs[1].data.params + end + + deleteexperiment(mlf, e1) + deleteexperiment(mlf, e2) +end