Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Higher level incident call #60

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/DiagrammaticEquations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ DerivOp = Symbol("∂ₜ")
append_dot(s::Symbol) = Symbol(string(s)*'\U0307')

include("acset.jl")
include("query.jl")
include("language.jl")
include("composition.jl")
include("collages.jl")
Expand Down
10 changes: 2 additions & 8 deletions src/acset.jl
Original file line number Diff line number Diff line change
Expand Up @@ -235,10 +235,7 @@ See also: [`infer_terminals`](@ref).
"""
function infer_states(d::SummationDecapode)
parentless = filter(parts(d, :Var)) do v
length(incident(d, v, :tgt)) == 0 &&
length(incident(d, v, :res)) == 0 &&
length(incident(d, v, :sum)) == 0 &&
d[v, :type] != :Literal
!is_var_target(d, v) && d[v, :type] != :Literal
end
parents_of_tvars =
union(d[incident(d,:∂ₜ, :op1), :src],
Expand All @@ -259,10 +256,7 @@ See also: [`infer_states`](@ref).
"""
function infer_terminals(d::SummationDecapode)
filter(parts(d, :Var)) do v
length(incident(d, v, :src)) == 0 &&
length(incident(d, v, :proj1)) == 0 &&
length(incident(d, v, :proj2)) == 0 &&
length(incident(d, v, :summand)) == 0
!is_var_source(d, v)
end
end

Expand Down
1 change: 1 addition & 0 deletions src/deca/Deca.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ export normalize_unicode, varname, infer_types!, resolve_overloads!, typename, s

include("deca_acset.jl")
include("deca_visualization.jl")
include("deca_query.jl")

""" function recursive_delete_parents!(d::SummationDecapode, to_delete::Vector{Int64})

Expand Down
25 changes: 25 additions & 0 deletions src/deca/deca_query.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
using DiagrammaticEquations
using ACSets

export is_var_target, is_var_source, get_variable_parents, get_next_op1s, get_next_op2s

function is_var_target(d::SummationDecapode, var::Int)
return !isempty(collected_incident(d, var, [:tgt, :res, :sum]))
end

function is_var_source(d::SummationDecapode, var::Int)
return !isempty(collected_incident(d, var, [:src, :proj1, :proj2, :summand]))
end

function get_variable_parents(d::SummationDecapode, var::Int)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we rename all these java getters to remove get_

return collected_incident(d, var, [:tgt, :res, :res, [:summation, :sum]], [:src, :proj1, :proj2, :summand])

Check warning on line 15 in src/deca/deca_query.jl

View check run for this annotation

Codecov / codecov/patch

src/deca/deca_query.jl#L14-L15

Added lines #L14 - L15 were not covered by tests
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have no idea what this is supposed to do, is this supposed to be a conjucnction of disjunctions type query?

end

function get_next_op1s(d::SummationDecapode, var::Int)
collected_incident(d, var, [:src])

Check warning on line 19 in src/deca/deca_query.jl

View check run for this annotation

Codecov / codecov/patch

src/deca/deca_query.jl#L18-L19

Added lines #L18 - L19 were not covered by tests
end

function get_next_op2s(d::SummationDecapode, var::Int)
collected_incident(d, var, [:proj1, :proj2])

Check warning on line 23 in src/deca/deca_query.jl

View check run for this annotation

Codecov / codecov/patch

src/deca/deca_query.jl#L22-L23

Added lines #L22 - L23 were not covered by tests
end

41 changes: 41 additions & 0 deletions src/query.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
using DiagrammaticEquations
using ACSets

export collected_incident

function collected_incident(d::ACSet, searches::AbstractVector, args...)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

docstring on exported function


isempty(searches) && error("Cannot have an empty search")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

shouldn't empty search return empty list?


query_result = mapreduce(vcat, searches) do search
collected_incident(d, search, args...)
end

return unique!(query_result)
end

function collected_incident(d::ACSet, search, lookup_array)
numof_channels = length(lookup_array)
empty_outputchannels = fill(nothing, numof_channels)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

isn't this an array of nothings? You can't push something into these "channels"

return collected_incident(d, search, lookup_array, empty_outputchannels)
end


function collected_incident(d::ACSet, search, lookup_array, output_array)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is the type of search?

length(lookup_array) == length(output_array) || error("Input and output channels are different lengths")
isempty(lookup_array) && error("Cannot have an empty lookup")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should empty lookup_array return an empty list? empty search => empty results makes sense to me.


query_result = mapreduce(vcat, zip(lookup_array, output_array)) do (lookup, output)
runincident_output_result(d, search, lookup, output)
end

return unique!(query_result)
end

function runincident_output_result(d::ACSet, search, lookup::Union{Symbol, AbstractVector{Symbol}}, output_channel::Union{Symbol, Nothing})
index_result = incident(d, search, lookup)
isnothing(output_channel) ? index_result : d[index_result, output_channel]
end

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this file feels over engineered for something simple, but I can't tell what it is supposed to do because of the lack of docstrings.



2 changes: 1 addition & 1 deletion test/composition.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ trivial_comp_from_vector = oapply(trivial_relation, [otrivial])
trivial_comp_from_single = oapply(trivial_relation, otrivial)

# Test the oapply is correct.
@test apex(trivial_comp_from_vector) == Trivial
@test apex(trivial_comp_from_vector) == Trivial
@test apex(trivial_comp_from_single) == Trivial
# Test none of the decapodes were mutated
@test isequal(otrivial, deep_copies)
Expand Down
111 changes: 111 additions & 0 deletions test/deca_query.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
using Test
using DiagrammaticEquations
using DiagrammaticEquations.Deca
using ACSets

function array_contains_same(test, expected)
sort(test) == sort(expected)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

so you want to handle multiplicity? Could replace with setequal(a,b) = Set(a) == Set(b)

end

get_index_from_name(d::SummationDecapode, varname::Symbol) = only(incident(d, varname, :name))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

no java please. just call this index rely on dispatch to use the name because you gave it a symbol.


@testset "Check sources and targets" begin
singleton_deca = @decapode begin
V::infer
end

@test !is_var_source(singleton_deca, 1)
@test !is_var_target(singleton_deca, 1)


path_op1_deca = @decapode begin
(X,Z)::infer
X == d(d(Z))
end

idxX = get_index_from_name(path_op1_deca, :X)
idxZ = get_index_from_name(path_op1_deca, :Z)

@test is_var_source(path_op1_deca, idxZ)
@test !is_var_target(path_op1_deca, idxZ)

@test !is_var_source(path_op1_deca, idxX)
@test is_var_target(path_op1_deca, idxX)


path_op2_deca = @decapode begin
X == ∧(Y,Z)
end

idxX = get_index_from_name(path_op2_deca, :X)
idxY = get_index_from_name(path_op2_deca, :Y)
idxZ = get_index_from_name(path_op2_deca, :Z)

idxsYZ = [idxY, idxZ]

for idx in idxsYZ
@test is_var_source(path_op2_deca, idx)
@test !is_var_target(path_op2_deca, idx)
end
@test !is_var_source(path_op2_deca, idxX)
@test is_var_target(path_op2_deca, idxX)


path_sum_deca = @decapode begin
X == Y + Z
end

idxX = get_index_from_name(path_sum_deca, :X)
idxY = get_index_from_name(path_sum_deca, :Y)
idxZ = get_index_from_name(path_sum_deca, :Z)

idxsYZ = [idxY, idxZ]

for idx in idxsYZ
@test is_var_source(path_sum_deca, idx)
@test !is_var_target(path_sum_deca, idx)
end
@test !is_var_source(path_sum_deca, idxX)
@test is_var_target(path_sum_deca, idxX)

mixedop_deca = @decapode begin
Inter == d(X) + ∧(Y, Z)
Res == d(Inter)
end

idxX = get_index_from_name(mixedop_deca, :X)
idxY = get_index_from_name(mixedop_deca, :Y)
idxZ = get_index_from_name(mixedop_deca, :Z)
idxInter = get_index_from_name(mixedop_deca, :Inter)
idxRes = get_index_from_name(mixedop_deca, :Res)

@test is_var_source(mixedop_deca, idxX)
@test is_var_source(mixedop_deca, idxY)
@test is_var_source(mixedop_deca, idxZ)

@test is_var_target(mixedop_deca, idxRes)

@test is_var_target(mixedop_deca, idxInter) && is_var_source(mixedop_deca, idxInter)
end

# TODO: Finish writing these tests
@testset "Get states and terminals" begin
singleton_deca = @decapode begin
V::Form1
end
@test infer_state_names(singleton_deca) == infer_terminal_names(singleton_deca)

path_op1_deca = @decapode begin
(X,Z)::infer
X == d(d(Z))
end
@test array_contains_same(infer_state_names(path_op1_deca), [:Z])
@test array_contains_same(infer_terminal_names(path_op1_deca), [:X])

path_op2_deca = @decapode begin
(X,Y,Z)::infer
X == ∧(Y,Z)
end
@test array_contains_same(infer_state_names(path_op2_deca), [:Y, :Z])
@test array_contains_same(infer_terminal_names(path_op2_deca), [:X])
end
Loading
Loading