Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support unprotected mode for specials with FullRank schemas as well #274

Merged
merged 7 commits into from
Mar 13, 2023
Merged
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
125 changes: 66 additions & 59 deletions src/schema.jl
Original file line number Diff line number Diff line change
Expand Up @@ -375,65 +375,6 @@ function apply_schema(t::FunctionTerm{typeof(unprotect)}, schema::Schema, Ctx::P
apply_schema(tt, schema, unprotect(Ctx))
end

"""
StatsModels.@support_unprotect f

Generate methods necessary for function `f` to support [`unprotect`](@ref)
inside of a `@formula`.

Any function call that occurs as a child of a protected call is also protected
by default. In order to support _unprotecting_ functions/operators that work
directly on `Term`s (like the built-in "special" operators `+`, `&`, `*`, and
`~`), we need to add methods for `apply_schema(::FunctionTerm{typeof(f)}, ...)`
that calls `f` on the captured arguments.

This macro generates the necessary method for `f`. For this to do something
reasonable, a few conditions must be met:

1. Methods must exist for `f(args::AbstractTerm...)` matching the specific
signatures that users provide when calling `f` in `@formula` (and usually,
returns an `AbstractTerm` of some kind).

2. The custom term type returned by `new_term = f(args::AbstractTerm...)` needs
to do something reasonable when `modelcols` is called on it.

3. The thing returned by `modelcols(new_term, data)` needs to be something that
can be consumed as input to whatever the parent call was for `f` in the
original formula expression.

To take a concrete example, if we have a function `g` that can do something
meaningful with the output of `modelcols(::InteractionTerm, ...)`, then when a
user provides something like

@formula(g(unprotect(a & b)))

that gets lowered to

FunctionTerm(g, [FuntionTerm(&, [Term(:a), Term(:b)], ...)], ...)

and we need to convert it to something like

FuntionTerm(g, [Term(:a) & Term(:b)], ...)

during schema application, which is what the method generated by `@support_unprotect &`
does.
"""
macro support_unprotect(op)
ex = quote
function apply_schema(t::StatsModels.FunctionTerm{typeof($op)},
sch::StatsModels.Schema,
Mod::Type)
apply_schema(t.f(t.args...), sch, Mod)
end
end
return esc(ex)
end

for op in SPECIALS
@eval @support_unprotect $op
end


"""
has_schema(t::T) where {T<:AbstractTerm}

Expand Down Expand Up @@ -595,3 +536,69 @@ termvars(t::TupleTerm) = mapreduce(termvars, union, t, init=Symbol[])
termvars(t::MatrixTerm) = termvars(t.terms)
termvars(t::FormulaTerm) = union(termvars(t.lhs), termvars(t.rhs))
termvars(t::FunctionTerm) = mapreduce(termvars, union, t.args, init=Symbol[])


"""
StatsModels.@support_unprotect f sch_types...

Generate methods necessary for function `f` to support [`unprotect`](@ref)
inside of a `@formula` with a schema of types `sch_types`. If not specified,
`sch_types` defaults to `(Schema, FullRank)` (the two schema types defined in
kleinschmidt marked this conversation as resolved.
Show resolved Hide resolved
StatsModels itself).

Any function call that occurs as a child of a protected call is also protected
by default. In order to support _unprotecting_ functions/operators that work
directly on `Term`s (like the built-in "special" operators `+`, `&`, `*`, and
`~`), we need to add methods for `apply_schema(::FunctionTerm{typeof(f)}, ...)`
that calls `f` on the captured arguments before further schema application.

This macro generates the necessary method for `f`. For this to do something
reasonable, a few conditions must be met:

1. Methods must exist for `f(args::AbstractTerm...)` matching the specific
signatures that users provide when calling `f` in `@formula` (and usually,
returns an `AbstractTerm` of some kind).

2. The custom term type returned by `new_term = f(args::AbstractTerm...)` needs
to do something reasonable when `modelcols` is called on it.

3. The thing returned by `modelcols(new_term, data)` needs to be something that
can be consumed as input to whatever the parent call was for `f` in the
original formula expression.

To take a concrete example, if we have a function `g` that can do something
meaningful with the output of `modelcols(::InteractionTerm, ...)`, then when a
user provides something like

@formula(g(unprotect(a & b)))

that gets lowered to

FunctionTerm(g, [FuntionTerm(&, [Term(:a), Term(:b)], ...)], ...)

and we need to convert it to something like

FuntionTerm(g, [Term(:a) & Term(:b)], ...)

during schema application, which is what the method generated by
`@support_unprotect &` does.
"""
macro support_unprotect(op, sch_types...)
sch_types = isempty(sch_types) ? (Schema, FullRank) : sch_types
ex = quote end
for sch_type in sch_types
sub_ex = quote
function StatsModels.apply_schema(t::StatsModels.FunctionTerm{typeof($op)},
sch::$sch_type,
Mod::Type)
apply_schema(t.f(t.args...), sch, Mod)
end
end
push!(ex.args, sub_ex)
end
return esc(ex)
end

for op in SPECIALS
@eval @support_unprotect $op
end