Skip to content

Commit

Permalink
Strides fix in broadcast (#504)
Browse files Browse the repository at this point in the history
* ignore oftype

* broadcast fixes

* delete problematic line that accidentally wasn't removed

* no print

* fix order
  • Loading branch information
chriselrod authored Jul 11, 2023
1 parent befd727 commit a21d6f8
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 27 deletions.
3 changes: 2 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
name = "LoopVectorization"
uuid = "bdcacae8-1622-11e9-2a5c-532679323890"
authors = ["Chris Elrod <[email protected]>"]
version = "0.12.162"
version = "0.12.163"


[deps]
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
Expand Down
54 changes: 28 additions & 26 deletions src/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,7 @@ for f ∈ [ # groupedstridedpointer support
:(ArrayInterface.contiguous_axis),
:(ArrayInterface.contiguous_batch_size),
:(ArrayInterface.device),
:(ArrayInterface.dense_dims),
:(ArrayInterface.stride_rank),
:(VectorizationBase.val_dense_dims),
:(ArrayInterface.offsets),
Expand All @@ -204,7 +205,9 @@ function is_column_major(x)
true
end
"""
    is_row_major(x)

Return `is_column_major(reverse(x))`, i.e. apply the column-major check to `x` with its
elements in reverse order.
"""
function is_row_major(x)
    return is_column_major(reverse(x))
end
# @inline _bytestrides(s,paren) = VectorizationBase.bytestrides(paren)
"""
    _find_arg_least_greater(r::Vector{Int}, i)

Find the smallest entry of `r` that is strictly greater than `i`.

Returns the `(value, index)` tuple produced by `findmin`. Entries that are not greater
than `i` are treated as `typemax(Int)`, so if no entry exceeds `i` the returned value is
`typemax(Int)` paired with the first index of `r`.
"""
function _find_arg_least_greater(r::Vector{Int}, i)
    return findmin(r) do candidate
        candidate > i ? candidate : typemax(Int)
    end
end

function _strides_expr(
@nospecialize(s),
@nospecialize(x),
Expand All @@ -214,20 +217,19 @@ function _strides_expr(
N = length(R)
q = Expr(:block, Expr(:meta, :inline))
strd_tup = Expr(:tuple)
resize!(strd_tup.args, N)
ifel = GlobalRef(Core, :ifelse)
Nrange = 1:1:N # type stability w/ respect to reverse
Nrange = 1:N # type stability w/ respect to reverse
# Nrange = 1:1:N # type stability w/ respect to reverse
use_stride_acc = true
stride_acc::Int = 1
if is_column_major(R)
# elseif is_row_major(R)
# Nrange = reverse(Nrange)
else # not worth my time optimizing this case at the moment...
# will write something generic stride-rank agnostic eventually
next, n = _find_arg_least_greater(R, 0)
if !D[n]
use_stride_acc = false
stride_acc = 0
end
sₙ_value::Int = 0
for n Nrange
for _n Nrange
xₙ_type = x[n]
xₙ_static = xₙ_type <: StaticInt
xₙ_value::Int = xₙ_static ? (xₙ_type.parameters[1])::Int : 0
Expand All @@ -236,38 +238,38 @@ function _strides_expr(
if sₙ_static
sₙ_value = s_type.parameters[1]
if s_type === One
push!(strd_tup.args, Expr(:call, lv(:Zero)))
strd_tup.args[n] = Expr(:call, lv(:Zero))
elseif stride_acc 0
push!(strd_tup.args, staticexpr(stride_acc))
strd_tup.args[n] = staticexpr(stride_acc)
else
push!(strd_tup.args, :($getfield(x, $n)))
strd_tup.args[n] = :($getfield(x, $n))
end
else
if xₙ_static
push!(strd_tup.args, staticexpr(xₙ_value))
strd_tup.args[n] = staticexpr(xₙ_value)
elseif stride_acc 0
push!(strd_tup.args, staticexpr(stride_acc))
strd_tup.args[n] = staticexpr(stride_acc)
else
push!(
strd_tup.args,
strd_tup.args[n] =
:($ifel(isone($getfield(s, $n)), zero($xₙ_type), $getfield(x, $n)))
)
end
end
if (n last(Nrange)) && use_stride_acc
nnext = n + step(Nrange)
if D[nnext]
if xₙ_static & sₙ_static
stride_acc = xₙ_value * sₙ_value
elseif sₙ_static
if stride_acc 0
stride_acc *= sₙ_value
if (_n N)
next, n = _find_arg_least_greater(R, next)
if use_stride_acc
if D[n]
if xₙ_static & sₙ_static
stride_acc = xₙ_value * sₙ_value
elseif sₙ_static
if stride_acc 0
stride_acc *= sₙ_value
end
else
stride_acc = 0
end
else
stride_acc = 0
end
else
stride_acc = 0
end
end
end
Expand Down

2 comments on commit a21d6f8

@chriselrod
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/87224

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the GitHub interface, or via:

git tag -a v0.12.163 -m "<description of version>" a21d6f87f07df6062a0dc6f775034600a5fa8331
git push origin v0.12.163

Please sign in to comment.