From 9edd1ec7eea60a3354a1b9dbe4356ff8f83ffa5c Mon Sep 17 00:00:00 2001 From: "Anthony D. Blaom" Date: Thu, 8 Apr 2021 12:21:14 +1200 Subject: [PATCH] add warning for use of shuffling in holdout --- src/constructors.jl | 9 +++++++-- test/constructors.jl | 8 +++++++- 2 files changed, 14 insertions(+), 3 deletions(-) diff --git a/src/constructors.jl b/src/constructors.jl index 6315227..c9c0d26 100644 --- a/src/constructors.jl +++ b/src/constructors.jl @@ -176,8 +176,6 @@ function IteratedModel(; model=nothing, model == nothing && throw(ERR_NO_MODEL) - - if model isa Deterministic iterated_model = DeterministicIteratedModel(model, controls, @@ -229,5 +227,12 @@ function MLJBase.clean!(iterated_model::EitherIteratedModel) iteration_parameter(iterated_model.model) === nothing && throw(ERR_NEED_PARAMETER) + if iterated_model.resampling isa Holdout && + iterated_model.resampling.shuffle + message *= "The use of sample-shuffling in `Holdout` "* + "will significantly slow training as "* + "each increment of the iteration parameter "* + "will force iteration from scratch (cold restart). " + end return message end diff --git a/test/constructors.jl b/test/constructors.jl index 7a1a009..c17ac8e 100644 --- a/test/constructors.jl +++ b/test/constructors.jl @@ -21,7 +21,13 @@ struct Bar <: MLJBase.Deterministic end @test iterated_model.measure == RootMeanSquaredError() @test_logs IteratedModel(model=model, measure=mae) - iterated_model = @test_logs IteratedModel(model=model, resampling=nothing) + @test_logs IteratedModel(model=model, resampling=nothing) + + @test_logs((:info, r"The use of sample"), + IteratedModel(model=model, + resampling=Holdout(rng=123), + measure=rms)) + end end