Skip to content

Commit

Permalink
Change ReConstructor to Reconstructor
Browse files Browse the repository at this point in the history
  • Loading branch information
paschermayr committed Apr 11, 2022
1 parent 0b633a9 commit c98257f
Show file tree
Hide file tree
Showing 8 changed files with 91 additions and 91 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "ModelWrappers"
uuid = "44c54197-9f56-47cc-9960-7f2e20bfb0d6"
authors = ["Patrick Aschermayr <[email protected]>"]
version = "0.2.0"
version = "0.2.1"

[deps]
ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197"
Expand Down
40 changes: 20 additions & 20 deletions src/Core/flatten/construct.jl
Original file line number Diff line number Diff line change
Expand Up @@ -116,37 +116,37 @@ Contains information for flatten/unflatten construct.
# Fields
$(TYPEDFIELDS)
"""
struct ReConstructor{F<:FlattenDefault, S<:FlattenConstructor, T<:UnflattenConstructor}
struct Reconstructor{F<:FlattenDefault, S<:FlattenConstructor, T<:UnflattenConstructor}
default::F
flatten::S
unflatten::T
function ReConstructor(default::F, flatten::S, unflatten::T) where {F<:FlattenDefault, S<:FlattenConstructor, T<:UnflattenConstructor}
function Reconstructor(default::F, flatten::S, unflatten::T) where {F<:FlattenDefault, S<:FlattenConstructor, T<:UnflattenConstructor}
return new{F,S,T}(default, flatten, unflatten)
end
end
function ReConstructor(flattendefault::FlattenDefault, x)
function Reconstructor(flattendefault::FlattenDefault, x)
# Assign flatten constructors
flatten, unflatten = construct_flatten(flattendefault, UnflattenStrict(), x)
flattenAD, unflattenAD = construct_flatten(flattendefault, UnflattenFlexible(), x)
flatten_constructor = FlattenConstructor(flatten, flattenAD)
unflatten_constructor = UnflattenConstructor(unflatten, unflattenAD)
# Return structs
return ReConstructor(flattendefault, flatten_constructor, unflatten_constructor)
return Reconstructor(flattendefault, flatten_constructor, unflatten_constructor)
end
function ReConstructor(flattendefault::FlattenDefault, constraint, x)
function Reconstructor(flattendefault::FlattenDefault, constraint, x)
# Assign flatten constructors
flatten, unflatten = construct_flatten(flattendefault, UnflattenStrict(), constraint, x)
flattenAD, unflattenAD = construct_flatten(flattendefault, UnflattenFlexible(), constraint, x)
flatten_constructor = FlattenConstructor(flatten, flattenAD)
unflatten_constructor = UnflattenConstructor(unflatten, unflattenAD)
# Return structs
return ReConstructor(flattendefault, flatten_constructor, unflatten_constructor)
return Reconstructor(flattendefault, flatten_constructor, unflatten_constructor)
end
function ReConstructor(x)
return ReConstructor(FlattenDefault(), x)
function Reconstructor(x)
return Reconstructor(FlattenDefault(), x)
end
function ReConstructor(constraint, x)
return ReConstructor(FlattenDefault(), constraint, x)
function Reconstructor(constraint, x)
return Reconstructor(FlattenDefault(), constraint, x)
end

############################################################################################
Expand All @@ -159,7 +159,7 @@ Convert 'x' into a Vector.
```
"""
function flatten end
function flatten(constructor::ReConstructor, x)
function flatten(constructor::Reconstructor, x)
return constructor.flatten.strict(x)
end

Expand All @@ -172,7 +172,7 @@ Convert 'x' into a Vector that is AD compatible.
```
"""
function flattenAD end
function flattenAD(constructor::ReConstructor, x)
function flattenAD(constructor::Reconstructor, x)
return constructor.flatten.flexible(x)
end

Expand All @@ -185,7 +185,7 @@ Unflatten 'x' into original shape.
```
"""
function unflatten end
function unflatten(constructor::ReConstructor, x)
function unflatten(constructor::Reconstructor, x)
return constructor.unflatten.strict(x)
end

Expand All @@ -198,7 +198,7 @@ Unflatten 'x' into original shape but keep type information of 'x' for AD compat
```
"""
function unflattenAD end
function unflattenAD(constructor::ReConstructor, x)
function unflattenAD(constructor::Reconstructor, x)
return constructor.unflatten.flexible(x)
end

Expand All @@ -212,20 +212,20 @@ Contains information to constrain and unconstrain parameter.
```julia
```
"""
struct TransformConstructor{S, T}
struct Transformconstructor{S, T}
constrain::S
unconstrain::T
function TransformConstructor(constraint, x)
function Transformconstructor(constraint, x)
#!NOTE: Transform is used to unconstrain, and inverse-transform to constrain parameter back.
transform, inverse_transform = construct_transform(constraint, x)
return new{typeof(inverse_transform), typeof(transform)}(inverse_transform, transform)
end
end

function constrain(transform::T, val) where {T<:TransformConstructor}
function constrain(transform::T, val) where {T<:Transformconstructor}
return constrain(transform.constrain, val)
end
function unconstrain(transform::T, val) where {T<:TransformConstructor}
function unconstrain(transform::T, val) where {T<:Transformconstructor}
return unconstrain(transform.unconstrain, val)
end

Expand All @@ -240,9 +240,9 @@ export FlattenTypes,
FlattenDefault,
FlattenConstructor,
UnflattenConstructor,
ReConstructor,
Reconstructor,
flatten,
flattenAD,
unflatten,
unflattenAD,
TransformConstructor
Transformconstructor
6 changes: 3 additions & 3 deletions src/Models/parameterinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ Contains information about parameter distributions, transformations and constrai
# Fields
$(TYPEDFIELDS)
"""
struct ParameterInfo{C,R<:ReConstructor,T<:TransformConstructor}
struct ParameterInfo{C,R<:Reconstructor,T<:Transformconstructor}
"Constraint distribution/boundaries for all model parameter."
constraint::C
"Contains information for flatten/unflatten parameter"
Expand All @@ -17,9 +17,9 @@ struct ParameterInfo{C,R<:ReConstructor,T<:TransformConstructor}
constraint::C, val::B, flattendefault::D
) where {C<:NamedTuple,B<:NamedTuple,D<:FlattenDefault}
## Create flatten constructor
constructor = ReConstructor(flattendefault, constraint, val)
constructor = Reconstructor(flattendefault, constraint, val)
## Assign transformer constraint NamedTuple
transformer = TransformConstructor(constraint, val)
transformer = Transformconstructor(constraint, val)
## Return ParameterInfo
return new{C,typeof(constructor),typeof(transformer)}(
constraint, constructor, transformer
Expand Down
4 changes: 2 additions & 2 deletions test/test-flatten.jl
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@
## If applicable, check if gradients for supported AD frameworks can be computed
if length(θ_flat) > 0
function check_AD_closure(constraint, val)
reconstruct = ReConstructor(constraint, val)
reconstruct = Reconstructor(constraint, val)
bij, bij⁻¹ = construct_transform(constraint, val)
function check_AD(θₜ::AbstractVector{T}) where {T<:Real}
θ = unflattenAD(reconstruct, θₜ)
Expand Down Expand Up @@ -136,7 +136,7 @@ end
## If applicable, check if gradients for supported AD frameworks can be computed
if length(θ_flat) > 0
function check_AD_closure(constraint, val)
reconstruct = ReConstructor(constraint, val)
reconstruct = Reconstructor(constraint, val)
bij, bij⁻¹ = construct_transform(constraint, val)
function check_AD(θₜ::AbstractVector{T}) where {T<:Real}
θ = unflattenAD(reconstruct, θₜ)
Expand Down
48 changes: 24 additions & 24 deletions test/test-flatten/constraints.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@
flatdefault = FlattenDefault(; output = output, flattentype = flattentype)
val = 2.
constraint = Bijectors.bijector(Gamma(2,2))
ReConstructor(constraint, val)
reconstruct = ReConstructor(flatdefault, constraint, val)
Reconstructor(constraint, val)
reconstruct = Reconstructor(flatdefault, constraint, val)
# Flatten
x_flat = flatten(reconstruct, val)
@test x_flat isa AbstractVector
Expand All @@ -32,7 +32,7 @@
x_transformed = _transform(val)
x_inversetransformed = _inversetransform(x_transformed)

transformer = TransformConstructor(constraint, val)
transformer = Transformconstructor(constraint, val)
val_uncon = unconstrain(transformer, val)
val_con = constrain(transformer, val_uncon)

Expand Down Expand Up @@ -60,8 +60,8 @@ end
flatdefault = FlattenDefault(; output = output, flattentype = flattentype)
val = 2.
constraint = Gamma(2,2)
ReConstructor(constraint, val)
reconstruct = ReConstructor(flatdefault, constraint, val)
Reconstructor(constraint, val)
reconstruct = Reconstructor(flatdefault, constraint, val)
# Flatten
x_flat = flatten(reconstruct, val)
@test x_flat isa AbstractVector
Expand All @@ -84,7 +84,7 @@ end
x_transformed = _transform(val)
x_inversetransformed = _inversetransform(x_transformed)

transformer = TransformConstructor(constraint, val)
transformer = Transformconstructor(constraint, val)
val_uncon = unconstrain(transformer, val)
val_con = constrain(transformer, val_uncon)

Expand Down Expand Up @@ -112,8 +112,8 @@ end
flatdefault = FlattenDefault(; output = output, flattentype = flattentype)
val = 2.
constraint = Constrained(1.,3.)
ReConstructor(constraint, val)
reconstruct = ReConstructor(flatdefault, constraint, val)
Reconstructor(constraint, val)
reconstruct = Reconstructor(flatdefault, constraint, val)
# Flatten
x_flat = flatten(reconstruct, val)
@test x_flat isa AbstractVector
Expand All @@ -136,7 +136,7 @@ end
x_transformed = _transform(val)
x_inversetransformed = _inversetransform(x_transformed)

transformer = TransformConstructor(constraint, val)
transformer = Transformconstructor(constraint, val)
val_uncon = unconstrain(transformer, val)
val_con = constrain(transformer, val_uncon)

Expand Down Expand Up @@ -164,8 +164,8 @@ end
flatdefault = FlattenDefault(; output = output, flattentype = flattentype)
val = 2.
constraint = Unconstrained()
ReConstructor(constraint, val)
reconstruct = ReConstructor(flatdefault, constraint, val)
Reconstructor(constraint, val)
reconstruct = Reconstructor(flatdefault, constraint, val)
# Flatten
x_flat = flatten(reconstruct, val)
@test x_flat isa AbstractVector
Expand All @@ -188,7 +188,7 @@ end
x_transformed = _transform(val)
x_inversetransformed = _inversetransform(x_transformed)

transformer = TransformConstructor(constraint, val)
transformer = Transformconstructor(constraint, val)
val_uncon = unconstrain(transformer, val)
val_con = constrain(transformer, val_uncon)

Expand Down Expand Up @@ -216,8 +216,8 @@ end
flatdefault = FlattenDefault(; output = output, flattentype = flattentype)
val = zeros(2,3,4)
constraint = Fixed()
ReConstructor(constraint, val)
reconstruct = ReConstructor(flatdefault, constraint, val)
Reconstructor(constraint, val)
reconstruct = Reconstructor(flatdefault, constraint, val)
# Flatten
x_flat = flatten(reconstruct, val)
@test x_flat isa AbstractVector
Expand All @@ -240,7 +240,7 @@ end
x_transformed = _transform(val)
x_inversetransformed = _inversetransform(x_transformed)

transformer = TransformConstructor(constraint, val)
transformer = Transformconstructor(constraint, val)
val_uncon = unconstrain(transformer, val)
val_con = constrain(transformer, val_uncon)

Expand Down Expand Up @@ -274,8 +274,8 @@ end
val[3,2] = val[2, 3] = 0.14
val
constraint = CorrelationMatrix()
ReConstructor(constraint, val)
reconstruct = ReConstructor(flatdefault, constraint, val)
Reconstructor(constraint, val)
reconstruct = Reconstructor(flatdefault, constraint, val)
# Flatten
x_flat = flatten(reconstruct, val)
@test x_flat isa AbstractVector
Expand Down Expand Up @@ -312,7 +312,7 @@ end
x_transformed = _transform(val)
x_inversetransformed = _inversetransform(x_transformed)

transformer = TransformConstructor(constraint, val)
transformer = Transformconstructor(constraint, val)
val_uncon = unconstrain(transformer, val)
val_con = constrain(transformer, val_uncon)

Expand Down Expand Up @@ -349,8 +349,8 @@ end
val[3,2] = val[2, 3] = 0.14
val
constraint = CovarianceMatrix()
ReConstructor(constraint, val)
reconstruct = ReConstructor(flatdefault, constraint, val)
Reconstructor(constraint, val)
reconstruct = Reconstructor(flatdefault, constraint, val)
# Flatten
x_flat = flatten(reconstruct, val)
@test x_flat isa AbstractVector
Expand Down Expand Up @@ -387,7 +387,7 @@ end
x_transformed = _transform(val)
x_inversetransformed = _inversetransform(x_transformed)

transformer = TransformConstructor(constraint, val)
transformer = Transformconstructor(constraint, val)
val_uncon = unconstrain(transformer, val)
val_con = constrain(transformer, val_uncon)

Expand Down Expand Up @@ -422,8 +422,8 @@ end
flatdefault = FlattenDefault(; output = output, flattentype = flattentype)
val = [.1, .2, .7]
constraint = Simplex(val)
ReConstructor(constraint, val)
reconstruct = ReConstructor(flatdefault, constraint, val)
Reconstructor(constraint, val)
reconstruct = Reconstructor(flatdefault, constraint, val)
# Flatten
x_flat = flatten(reconstruct, val)
@test x_flat isa AbstractVector
Expand Down Expand Up @@ -460,7 +460,7 @@ end
x_transformed = _transform(val)
x_inversetransformed = _inversetransform(x_transformed)

transformer = TransformConstructor(constraint, val)
transformer = Transformconstructor(constraint, val)
val_uncon = unconstrain(transformer, val)
val_con = constrain(transformer, val_uncon)

Expand Down
2 changes: 1 addition & 1 deletion test/test-flatten/flatten.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
############################################################################################
function check_AD_closure(constraint, val)
reconstruct = ReConstructor(constraint, val)
reconstruct = Reconstructor(constraint, val)
bij, bij⁻¹ = construct_transform(constraint, val)
function check_AD(θₜ::AbstractVector{T}) where {T<:Real}
θ = unflattenAD(reconstruct, θₜ)
Expand Down
Loading

2 comments on commit c98257f

@paschermayr
Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/58332

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.2.1 -m "<description of version>" c98257fc0c3a81c2f8998f87f274224643d35148
git push origin v0.2.1

Please sign in to comment.