From 9edd1ec7eea60a3354a1b9dbe4356ff8f83ffa5c Mon Sep 17 00:00:00 2001 From: "Anthony D. Blaom" Date: Thu, 8 Apr 2021 12:21:14 +1200 Subject: [PATCH 1/4] add warning for use of shuffling in holdout --- src/constructors.jl | 9 +++++++-- test/constructors.jl | 8 +++++++- 2 files changed, 14 insertions(+), 3 deletions(-) diff --git a/src/constructors.jl b/src/constructors.jl index 6315227..c9c0d26 100644 --- a/src/constructors.jl +++ b/src/constructors.jl @@ -176,8 +176,6 @@ function IteratedModel(; model=nothing, model == nothing && throw(ERR_NO_MODEL) - - if model isa Deterministic iterated_model = DeterministicIteratedModel(model, controls, @@ -229,5 +227,12 @@ function MLJBase.clean!(iterated_model::EitherIteratedModel) iteration_parameter(iterated_model.model) === nothing && throw(ERR_NEED_PARAMETER) + if iterated_model.resampling isa Holdout && + iterated_model.resampling.shuffle + message *= "The use of sample-shuffling in `Holdout` "* + "will significantly slow training as "* + "each increment of the iteration parameter "* + "will force iteration from scratch (cold restart). " + end return message end diff --git a/test/constructors.jl b/test/constructors.jl index 7a1a009..c17ac8e 100644 --- a/test/constructors.jl +++ b/test/constructors.jl @@ -21,7 +21,13 @@ struct Bar <: MLJBase.Deterministic end @test iterated_model.measure == RootMeanSquaredError() @test_logs IteratedModel(model=model, measure=mae) - iterated_model = @test_logs IteratedModel(model=model, resampling=nothing) + @test_logs IteratedModel(model=model, resampling=nothing) + + @test_logs((:info, r"The use of sample"), + IteratedModel(model=model, + resampling=Holdout(rng=123), + measure=rms)) + end end From 9cb64bf76b18e6e8431a35bda3ad5a9f18b59c1b Mon Sep 17 00:00:00 2001 From: "Anthony D. Blaom" Date: Thu, 8 Apr 2021 13:58:45 +1200 Subject: [PATCH 2/4] address #11; rm CategoricalArrays as explicit [extras] --- Project.toml | 3 +-- src/MLJIteration.jl | 4 ++++ src/constructors.jl | 14 ++++++++++++++ test/_dummy_model.jl | 2 +- test/constructors.jl | 4 ++++ 5 files changed, 24 insertions(+), 3 deletions(-) diff --git a/Project.toml b/Project.toml index cbd7c6d..51a33b7 100644 --- a/Project.toml +++ b/Project.toml @@ -14,11 +14,10 @@ MLJBase = "0.17.7, 0.18" julia = "1" [extras] -CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597" MLJModelInterface = "e80e1ace-859a-464e-9ed9-23947d8ae3ea" StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["CategoricalArrays", "MLJModelInterface", "StableRNGs", "Statistics", "Test"] +test = ["MLJModelInterface", "StableRNGs", "Statistics", "Test"] diff --git a/src/MLJIteration.jl b/src/MLJIteration.jl index 72272ea..6d1d4f4 100644 --- a/src/MLJIteration.jl +++ b/src/MLJIteration.jl @@ -12,6 +12,8 @@ const CONTROLS = vcat(IterationControl.CONTROLS, :WithEvaluationDo, :CycleLearningRate]) +const TRAINING_CONTROLS = [:Step, ] + # export all control types: for control in CONTROLS eval(:(export $control)) @@ -30,5 +32,7 @@ include("ic_model.jl") include("controls.jl") include("core.jl") +const Control = Union{[@eval($c) for c in CONTROLS]...} +const TrainingControl = Union{[@eval($c) for c in TRAINING_CONTROLS]...} end # module diff --git a/src/constructors.jl b/src/constructors.jl index c9c0d26..77540e2 100644 --- a/src/constructors.jl +++ b/src/constructors.jl @@ -1,3 +1,9 @@ +const ERR_MISSING_TRAINING_CONTROL = + ArgumentError("At least one control must be a training control "* + "(ie, be on this list: $TRAINING_CONTROLS) or be a "* + "custom control that calls IterationControl.train!. ") + + ## TYPES AND CONSTRUCTOR mutable struct DeterministicIteratedModel{M<:Deterministic} <: MLJBase.Deterministic @@ -234,5 +240,13 @@ function MLJBase.clean!(iterated_model::EitherIteratedModel) "each increment of the iteration parameter "* "will force iteration from scratch (cold restart). " end + + training_control_candidates = filter(iterated_model.controls) do c + c isa TrainingControl || !(c isa Control) + end + if isempty(training_control_candidates) + throw(ERR_MISSING_TRAINING_CONTROL) + end + return message end diff --git a/test/_dummy_model.jl b/test/_dummy_model.jl index 49129ab..c2b5ddf 100644 --- a/test/_dummy_model.jl +++ b/test/_dummy_model.jl @@ -5,7 +5,7 @@ export DummyIterativeModel, make_dummy using Random using Statistics import StableRNGs.LehmerRNG -using CategoricalArrays +using MLJBase.CategoricalArrays import Base.== using MLJModelInterface diff --git a/test/constructors.jl b/test/constructors.jl index c17ac8e..242e2cd 100644 --- a/test/constructors.jl +++ b/test/constructors.jl @@ -28,6 +28,10 @@ struct Bar <: MLJBase.Deterministic end resampling=Holdout(rng=123), measure=rms)) + @test_throws(MLJIteration.ERR_MISSING_TRAINING_CONTROL, + IteratedModel(model=model, + resampling=nothing, + controls=[Patience(), NotANumber()])) end end From 79868364b762ae6db02fab5e908146543942d4ac Mon Sep 17 00:00:00 2001 From: "Anthony D. Blaom" Date: Thu, 8 Apr 2021 17:26:21 +1200 Subject: [PATCH 3/4] minor --- src/MLJIteration.jl | 9 ++++++--- src/constructors.jl | 2 +- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/src/MLJIteration.jl b/src/MLJIteration.jl index 6d1d4f4..58791ef 100644 --- a/src/MLJIteration.jl +++ b/src/MLJIteration.jl @@ -26,13 +26,16 @@ const CONTROLS_DEFAULT = [Step(10), NotANumber()] include("utilities.jl") +include("controls.jl") + +const Control = Union{[@eval($c) for c in CONTROLS]...} +const TrainingControl = Union{[@eval($c) for c in TRAINING_CONTROLS]...} + include("constructors.jl") include("traits.jl") include("ic_model.jl") -include("controls.jl") include("core.jl") -const Control = Union{[@eval($c) for c in CONTROLS]...} -const TrainingControl = Union{[@eval($c) for c in TRAINING_CONTROLS]...} + end # module diff --git a/src/constructors.jl b/src/constructors.jl index 77540e2..96d8044 100644 --- a/src/constructors.jl +++ b/src/constructors.jl @@ -1,6 +1,6 @@ const ERR_MISSING_TRAINING_CONTROL = ArgumentError("At least one control must be a training control "* - "(ie, be on this list: $TRAINING_CONTROLS) or be a "* + "(have type `$TrainingControl`) or be a "* "custom control that calls IterationControl.train!. ") From fdfe798941e708ae9d5f3e9c0f8ca51d55ae43b9 Mon Sep 17 00:00:00 2001 From: "Anthony D. Blaom" Date: Thu, 8 Apr 2021 17:45:00 +1200 Subject: [PATCH 4/4] bump 0.2.2 --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 51a33b7..62fe738 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "MLJIteration" uuid = "614be32b-d00c-4edb-bd02-1eb411ab5e55" authors = ["Anthony D. Blaom "] -version = "0.2.1" +version = "0.2.2" [deps] IterationControl = "b3c1a2ee-3fec-4384-bf48-272ea71de57c"