Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow for Tasks as a data-dependency in a DataFlowTask #80

Merged
merged 1 commit into from
Nov 17, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,4 @@ JuliaSysimage.dll
Manifest.toml
TiledFactorization
docs/.DS_Store
docs/slides/**
23 changes: 11 additions & 12 deletions src/dag.jl
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,7 @@
function update_edges!(dag::DAG, nodej)
transitively_connected = dag._buffer
empty!(transitively_connected)
hastask = any(x -> isa(x, Task), data(nodej))
# update dependencies from newer to older and reinfornce transitivity by
# skipping predecessors of nodes which are already connected
for (nodei, _) in Iterators.reverse(dag)
Expand All @@ -189,21 +190,19 @@
end
ti = tag(nodei)
(ti ∈ transitively_connected) && continue
# if a DataFlowTask is in data, add the edge directly to the DAG
@assert nodei ≤ nodej "i = $(nodei.tag), j = $(nodej.tag)"
dep = data_dependency(nodei, nodej)
dep || continue
# tasks are handled differently when they appear in the data in that
# they are checked directly agains the nodej.task field
dep = false
if hastask
for d in data(nodej)
d === nodei.task && (dep = true; break)
end

Check warning on line 199 in src/dag.jl

View check run for this annotation

Codecov / codecov/patch

src/dag.jl#L199

Added line #L199 was not covered by tests
end
dep || data_dependency(nodei, nodej) || continue
addedge!(dag, nodei, nodej)
update_transitively_connected!(transitively_connected, nodei, dag)
# addedge_transitive!(dag,nodei,nodej)
end
# if a DataFlowTask is in data and it is still active, add the edge directly to the DAG
for d in data(nodej)
(d isa DataFlowTask) &&
(tag(d) ∉ transitively_connected) &&
haskey(dag.inoutlist, d) &&
addedge!(dag, d, nodej)
end

return dag
end

Expand Down
4 changes: 2 additions & 2 deletions src/dataflowtask.jl
Original file line number Diff line number Diff line change
Expand Up @@ -115,9 +115,9 @@ end

@noinline function _data_dependency(datai, modei, dataj, modej)
for (di, mi) in zip(datai, modei)
(di isa DataFlowTask) && continue
(di isa Task) && continue # Tasks are handled differently
for (dj, mj) in zip(dataj, modej)
(dj isa DataFlowTask) && continue
(dj isa Task) && continue # Tasks are handled differently
mi == READ && mj == READ && continue
if memory_overlap(di, dj)
return true
Expand Down
2 changes: 1 addition & 1 deletion src/taskgraph.jl
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ function stop_dag_cleaner(tg::TaskGraph = get_active_taskgraph())
return tg.dag_cleaner
else # expected result, task is running
put!(tg.finished, Stop())
# wait for t to stop before continuining
# wait for t to stop before continuing
wait(t)
end
return tg.dag_cleaner
Expand Down
15 changes: 15 additions & 0 deletions test/dataflowtask_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -135,3 +135,18 @@ end
@test typeof(s) == Task
@inferred test_seq_mode(x)
end

@testset "Fetching task" begin
DataFlowTasks.set_active_taskgraph!(DataFlowTasks.TaskGraph())
d1 = @dspawn (sleep(0.01); rand(10)) label = "sleep"
d2 = @dspawn fill!(fetch(@R(d1)), 0) label = "fill"
@test fetch(d2) |> sum == 0
# make sure that d2 depends on d1 by checking the length of the critical
# path
log_info = DataFlowTasks.@log begin
d1 = @dspawn (sleep(0.01); rand(10)) label = "sleep"
d2 = @dspawn fill!(fetch(@R(d1)), 0) label = "fill"
fetch(d2)
end
@test length(DataFlowTasks.longest_path(log_info)) == 2
end