Skip to content

Commit

Permalink
Tweak search process
Browse files Browse the repository at this point in the history
Signed-off-by: Matthew Ballance <[email protected]>
  • Loading branch information
mballance committed Jan 24, 2025
1 parent ea66719 commit 1d81cbc
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 5 deletions.
1 change: 1 addition & 0 deletions src/dv_flow/mgr/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@

from .package_def import *
from .pkg_rgy import PkgRgy
from .task_graph_runner import *
from .task import *
from .task_data import *
Expand Down
10 changes: 10 additions & 0 deletions src/dv_flow/mgr/task_graph_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,10 @@ def getPackage(self, spec : PackageSpec) -> Package:
pkg = self._pkg_m[tgt_pkg_spec]
elif self.pkg_rgy.hasPackage(tgt_pkg_spec.name):
base = self.pkg_rgy.getPackage(tgt_pkg_spec.name)
pkg_m = self._pkg_m
self._pkg_m = {}
pkg = base.mkPackage(self, spec.params)
self._pkg_m = pkg_m
self._pkg_m[spec] = pkg
elif imp.path is not None:
# See if we can load the package
Expand All @@ -172,6 +175,9 @@ def getPackage(self, spec : PackageSpec) -> Package:
if base is None:
raise Exception("Failed to find imported package %s" % spec.name)
pkg = base.mkPackage(self, spec.params)
pkg_m = self._pkg_m
self._pkg_m = {}
self._pkg_m = pkg_m
self._pkg_m[spec] = pkg
break

Expand All @@ -180,7 +186,11 @@ def getPackage(self, spec : PackageSpec) -> Package:
p_def = self.pkg_rgy.getPackage(spec.name)

if p_def is not None:
pkg_m = self._pkg_m
self._pkg_m = {}
pkg = p_def.mkPackage(self)
self._pkg_m = pkg_m
self._pkg_m[spec] = pkg

if pkg is None:
raise Exception("Failed to find package %s from package %s" % (
Expand Down
29 changes: 24 additions & 5 deletions tests/sys/test_pkg_discovery.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import pytest
import subprocess
import sys
from dv_flow.mgr import PkgRgy


def test_import_specific(tmpdir):
Expand Down Expand Up @@ -61,12 +62,12 @@ def test_import_alias(tmpdir):
name: p1
imports:
- name: p2
as: p3
- name: p2.foo
as: p2
tasks:
- name: my_task
uses: p3.doit
uses: p2.doit
"""

p2_flow_dv = """
Expand All @@ -77,7 +78,18 @@ def test_import_alias(tmpdir):
- name: doit
uses: std.Message
with:
msg: "Hello There"
msg: "Hello There (p2)"
"""

p2_foo_flow_dv = """
package:
name: p2.foo
tasks:
- name: doit
uses: std.Message
with:
msg: "Hello There (p2.foo)"
"""

rundir = os.path.join(tmpdir)
Expand All @@ -89,6 +101,13 @@ def test_import_alias(tmpdir):
with open(os.path.join(rundir, "p2/flow.dv"), "w") as fp:
fp.write(p2_flow_dv)

with open(os.path.join(rundir, "p2/foo.dv"), "w") as fp:
fp.write(p2_foo_flow_dv)

# pkg_rgy = PkgRgy()
# pkg_rgy.registerPackage("p2", os.path.join(rundir, "p2/flow.dv"))
# pkg_rgy.registerPackage("p2.foo", os.path.join(rundir, "p2/foo.dv"))

env = os.environ.copy()
env["DV_FLOW_PATH"] = rundir

Expand All @@ -104,5 +123,5 @@ def test_import_alias(tmpdir):

output = output.decode()

assert output.find("Hello There") != -1
assert output.find("Hello There (p2.foo)") != -1

0 comments on commit 1d81cbc

Please sign in to comment.