diff --git a/ext/UnitfulExt.jl b/ext/UnitfulExt.jl index a709bd878..dcb73f3c7 100644 --- a/ext/UnitfulExt.jl +++ b/ext/UnitfulExt.jl @@ -5,9 +5,9 @@ module UnitfulExt import Plots: Plots, @ext_imp_use, @recipe, PlotText, Subplot, AVec, AMat, Axis import RecipesBase -@ext_imp_use :import Unitful Quantity unit ustrip Unitful dimension Units NoUnits +@ext_imp_use :import Unitful Quantity unit ustrip Unitful dimension Units NoUnits LogScaled logunit MixedUnits Level Gain uconvert -const MissingOrQuantity = Union{Missing,<:Quantity} +const MissingOrQuantity = Union{Missing,<:Quantity,<:LogScaled} #========== Main recipe @@ -17,7 +17,7 @@ Main recipe axisletter = plotattributes[:letter] # x, y, or z clims_types = (:contour, :contourf, :heatmap, :surface) if axisletter === :z && get(plotattributes, :seriestype, :nothing) ∈ clims_types - u = get(plotattributes, :zunit, unit(eltype(x))) + u = get(plotattributes, :zunit, _unit(eltype(x))) ustripattribute!(plotattributes, :clims, u) append_unit_if_needed!(plotattributes, :colorbar_title, u) end @@ -33,7 +33,7 @@ function fixaxis!(attr, x, axisletter) axisunit = Symbol(axisletter, :unit) # xunit, yunit, zunit axis = Symbol(axisletter, :axis) # xaxis, yaxis, zaxis # Get the unit - u = pop!(attr, axisunit, unit(eltype(x))) + u = pop!(attr, axisunit, _unit(eltype(x))) # If the subplot already exists with data, get its unit sp = get(attr, :subplot, 1) if sp ≤ length(attr[:plot_object]) && attr[:plot_object].n > 0 @@ -54,12 +54,12 @@ function fixaxis!(attr, x, axisletter) fixmarkersize!(attr) fixlinecolor!(attr) # Strip the unit - ustrip.(u, x) + _ustrip.(u, x) end # Recipe for (x::AVec, y::AVec, z::Surface) types @recipe function f(x::AVec, y::AVec, z::AMat{T}) where {T<:Quantity} # COV_EXCL_LINE - u = get(plotattributes, :zunit, unit(eltype(z))) + u = get(plotattributes, :zunit, _unit(eltype(z))) ustripattribute!(plotattributes, :clims, u) z = fixaxis!(plotattributes, z, :z) append_unit_if_needed!(plotattributes, :colorbar_title, u) @@ -159,8 +159,8 @@ fixlinecolor!(attr) = ustripattribute!(attr, :line_z) ustripattribute!(attr, key) = if haskey(attr, key) v = attr[key] - u = unit(eltype(v)) - attr[key] = ustrip.(u, v) + u = _unit(eltype(v)) + attr[key] = _ustrip.(u, v) return u else return NoUnits @@ -170,7 +170,7 @@ function ustripattribute!(attr, key, u) if haskey(attr, key) v = attr[key] if eltype(v) <: Quantity - attr[key] = ustrip.(u, v) + attr[key] = _ustrip.(u, v) end end u @@ -204,7 +204,7 @@ Plots.protectedstring(s) = ProtectedString(s) Append unit to labels when appropriate =====================================# -append_unit_if_needed!(attr, key, u::Units) = +append_unit_if_needed!(attr, key, u) = append_unit_if_needed!(attr, key, get(attr, key, nothing), u) # dispatch on the type of `label` append_unit_if_needed!(attr, key, label::ProtectedString, u) = nothing @@ -257,25 +257,36 @@ Plots.locate_annotation( x::MissingOrQuantity, y::MissingOrQuantity, label::PlotText, -) = (ustrip(x), ustrip(y), label) +) = (_ustrip(x), _ustrip(y), label) Plots.locate_annotation( sp::Subplot, x::MissingOrQuantity, y::MissingOrQuantity, z::MissingOrQuantity, label::PlotText, -) = (ustrip(x), ustrip(y), ustrip(z), label) +) = (_ustrip(x), _ustrip(y), _ustrip(z), label) Plots.locate_annotation(sp::Subplot, rel::NTuple{N,<:MissingOrQuantity}, label) where {N} = - Plots.locate_annotation(sp, ustrip.(rel), label) + Plots.locate_annotation(sp, _ustrip.(rel), label) #==================# # ticks and limits # #==================# Plots._transform_ticks(ticks::AbstractArray{T}, axis) where {T<:Quantity} = - ustrip.(getaxisunit(axis), ticks) + _ustrip.(getaxisunit(axis), ticks) Plots.process_limits(lims::AbstractArray{T}, axis) where {T<:Quantity} = - ustrip.(getaxisunit(axis), lims) + _ustrip.(getaxisunit(axis), lims) Plots.process_limits(lims::Tuple{S,T}, axis) where {S<:Quantity,T<:Quantity} = - ustrip.(getaxisunit(axis), lims) + _ustrip.(getaxisunit(axis), lims) + +function _ustrip(u, x) + u isa MixedUnits && return ustrip(uconvert(u, x)) + return ustrip(u, x) +end + +function _unit(x) + t = eltype(x) + t <: LogScaled && return logunit(t) + return unit(x) +end end # module diff --git a/test/test_unitful.jl b/test/test_unitful.jl index bc6255140..01665bf41 100644 --- a/test/test_unitful.jl +++ b/test/test_unitful.jl @@ -377,3 +377,18 @@ end @test ncodeunits(str) == 4 @test codeunit(str) == UInt8 end + +@testset "Logunits plots" begin + u = (1:3)u"B" + v = (1:3)u"dB" + x = (1:3)u"dBV" + y = (1:3)u"V" + pl = plot(u, x) + @test pl isa Plot + @test xguide(pl) == "B" + @test yguide(pl) == "dBV" + @test plot!(pl, v, y) isa Plot + pl = plot(v, y) + @test pl isa Plot + @test plot!(pl, u, x) isa Plot +end