Skip to content

Commit

Permalink
more supporting functions; fixes permindex for the new KnnResult
Browse files Browse the repository at this point in the history
  • Loading branch information
sadit committed Feb 22, 2023
1 parent bb60067 commit 2b36054
Show file tree
Hide file tree
Showing 4 changed files with 74 additions and 10 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "SimilaritySearch"
uuid = "053f045d-5466-53fd-b400-a066f88fe02a"
authors = ["Eric S. Tellez <[email protected]>"]
version = "0.10.2"
version = "0.10.3"

[deps]
Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
Expand Down
72 changes: 68 additions & 4 deletions src/adj.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,22 @@
module AdjacencyLists

abstract type AbstractAdjacencyList{EndPointType} end
export AbstractAdjacencyList, IdWeight, IdIntWeight, AdjacencyList, StaticAdjacencyList, neighbors, add_edge!, add_vertex!,
IdOrder, WeightOrder, RevWeightOrder, sort_last_item!
export AbstractAdjacencyList, AdjacencyList, StaticAdjacencyList,
neighbors, add_edge!, add_edges!, add_vertex!, neighbors_length,
IdWeight, IdIntWeight,
sort_last_item!, IdOrder, WeightOrder, RevWeightOrder

using Base.Order
import Base.Order: lt

Base.eachindex(adj::AbstractAdjacencyList) = 1:length(adj)

function Base.iterate(adj::AbstractAdjacencyList, i::Int=1)
n = length(adj)
(n == 0 || i > n) && return nothing
@inbounds neighbors(adj, i), i+1
end

struct IdWeight
id::UInt32
weight::Float32
Expand Down Expand Up @@ -68,16 +76,22 @@ struct AdjacencyList{EndPointType} <: AbstractAdjacencyList{EndPointType}
glock::Threads.SpinLock
end

Base.eltype(adj::AdjacencyList{EndPointType}) where EndPointType = Vector{EndPointType}


function AdjacencyList(lists::Vector{Vector{EndPointType}}) where EndPointType
locks = [Threads.SpinLock() for _ in 1:length(lists)]
AdjacencyList{EndPointType}(lists, EndPointType[], locks, Threads.SpinLock())
end

function AdjacencyList(::Type{EndPointType}; n::Int=0) where EndPointType

function AdjacencyList(::Type{EndPointType}, n::Int) where EndPointType
lists = Vector{Vector{EndPointType}}(undef, n)
AdjacencyList(lists)
end

AdjacencyList(t::Type{EndPointType}; n::Int=0) where EndPointType = AdjacencyList(t, n::Int)

function Base.resize!(adj::AdjacencyList, n)
lock(adj.glock)

Expand All @@ -104,6 +118,11 @@ Base.@propagate_inbounds @inline function neighbors(adj::AdjacencyList, i::Integ
isassigned(adj.end_point, i) ? adj.end_point[i] : adj.empty_cent
end

Base.@propagate_inbounds @inline function neighbors_length(adj::AdjacencyList, i::Integer)
# we can access undefined posting lists, it is responsability of the algorithm to ensure this doesn't happens
isassigned(adj.end_point, i) ? length(adj.end_point[i]) : 0
end

Base.@propagate_inbounds @inline function add_edge!(adj::AdjacencyList{EndPointType}, i::Integer, end_point, order=nothing) where EndPointType
@inbounds lock(adj.locks[i])

Expand All @@ -122,6 +141,38 @@ Base.@propagate_inbounds @inline function add_edge!(adj::AdjacencyList{EndPointT
adj
end

Base.@propagate_inbounds @inline function add_edges!(adj::AdjacencyList{EndPointType}, i::Integer, neighbors::Vector{EndPointType}) where EndPointType
@inbounds lock(adj.locks[i])
try
if isassigned(adj.end_point, i)
push!(adj.end_point[i], neighbors)
else
adj.end_point[i] = neighbors
end
finally
@inbounds unlock(adj.locks[i])
end

adj
end

Base.@propagate_inbounds @inline function add_edges!(adj::AdjacencyList{EndPointType}, i::Integer, neighbors) where EndPointType
@inbounds lock(adj.locks[i])
try
if !isassigned(adj.end_point, i)
adj.end_point[i] = neighbors
end

for p in neighbors
push!(adj.end_point[i], p)
end
finally
@inbounds unlock(adj.locks[i])
end

adj
end

Base.@propagate_inbounds @inline function add_vertex!(adj::AdjacencyList{T}) where T
add_vertex!(adj, T[])
end
Expand All @@ -145,21 +196,34 @@ struct StaticAdjacencyList{EndPointType} <: AbstractAdjacencyList{EndPointType}
end

Base.length(adj::StaticAdjacencyList) = length(adj.offset)
Base.eltype(adj::StaticAdjacencyList{EndPointType}) where EndPointType = typeof(view(adj.end_point, 1:1))

function StaticAdjacencyList(adj::StaticAdjacencyList; offset=adj.offset, end_point=adj.end_point)
StaticAdjacencyList(offset, end_point)
end

Base.@propagate_inbounds @inline function neighbors(adj::StaticAdjacencyList, i::Integer)
@inbounds sp::Int64 = i == 1 ? 1 : adj.offset[i-1]+1
@inbounds sp::Int64 = i == 1 ? 1 : adj.offset[i-1] + 1
@inbounds ep = adj.offset[i]
view(adj.end_point, sp:ep)
end

Base.@propagate_inbounds @inline function neighbors_length(adj::StaticAdjacencyList, i::Integer)
@inbounds sp::Int64 = i == 1 ? 1 : adj.offset[i-1] + 1
@inbounds ep = adj.offset[i]
length(ep - sp + 1)
end



function add_edge!(adj::StaticAdjacencyList, i::Integer, end_point)
error("ERROR: unsupported add_edge! on a static adjacent list")
end

function add_edges!(adj::StaticAdjacencyList, i::Integer, neighbors)
error("ERROR: unsupported add_edges! on a static adjacent list")
end

function add_vertex!(adj::StaticAdjacencyList)
error("ERROR: unsupported add_vertext! on a static adjacent list")
end
Expand Down
5 changes: 2 additions & 3 deletions src/knnresult.jl
Original file line number Diff line number Diff line change
Expand Up @@ -112,9 +112,8 @@ end
Access the i-th item in `res`
"""
@inline function Base.getindex(res::KnnResult, i)
@inbounds res.items[i]
end
@inline Base.getindex(res::KnnResult, i) = (@inbounds res.items[i])
@inline Base.setindex!(res::KnnResult, item::IdWeight, i::Integer) = (res.items[i] = item)

@inline Base.last(res::KnnResult) = last(res.items)
@inline Base.first(res::KnnResult) = @inbounds first(res.items)
Expand Down
5 changes: 3 additions & 2 deletions src/permindex.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,9 @@ PermutedSearchIndex(; index, π, π′=invperm(π)) = PermutedSearchIndex(index,

function search(p::PermutedSearchIndex, q, res::KnnResult; pools=getpools(index))
out = search(p.index, q, res; pools)
@inbounds for i in eachindex(res.id)
res.id[i] = p.π[res.id[i]]
@inbounds for i in eachindex(res)
x = res[i]
res[i] = IdWeight(p.π[x.id], x.weight)
end

out
Expand Down

2 comments on commit 2b36054

@sadit
Copy link
Owner Author

@sadit sadit commented on 2b36054 Feb 22, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator register()

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/78286

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.10.3 -m "<description of version>" 2b3605498772abf58c6228e581561b004ff60c45
git push origin v0.10.3

Please sign in to comment.