Skip to content

Commit

Permalink
Checkpoint on ext lib infra
Browse files Browse the repository at this point in the history
Signed-off-by: Matthew Ballance <[email protected]>
  • Loading branch information
mballance committed Jan 19, 2025
1 parent b8de048 commit 23a6013
Show file tree
Hide file tree
Showing 9 changed files with 378 additions and 296 deletions.
20 changes: 0 additions & 20 deletions src/dv_flow_mgr/package.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,26 +54,6 @@ def getPackage(self, name : str) -> 'Package':
def getTaskCtor(self, name : str) -> TaskCtor:
return self.tasks[name]

def mkTaskParams(self, name : str) -> TaskParams:
if name not in self.tasks:
raise Exception("Task " + name + " not found")
return self.tasks[name].mkTaskParams()

def setTaskParams(self, name : str, params : TaskParams, pvals : Dict[str,Any]):
if name not in self.tasks:
raise Exception("Task " + name + " not found")
self.tasks[name].setTaskParams(params, pvals)

def mkTask(self,
name : str,
task_id : int,
session : 'Session',
params : TaskParams,
depends : List['Task']) -> 'Task':
# TODO: combine parameters to create the full taskname
task = self.tasks[name].mkTask(name, task_id, session, params, depends)
return task

def __hash__(self):
return hash(self.fullname())

237 changes: 138 additions & 99 deletions src/dv_flow_mgr/package_def.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,14 @@
import pydantic
import pydantic.dataclasses as dc
from pydantic import BaseModel
from typing import Any, Dict, List
from typing import Any, Dict, List, Callable, Tuple
from .flow import Flow
from .fragment_def import FragmentDef
from .package import Package
from .package_import_spec import PackageImportSpec, PackageSpec
from .task import TaskParamCtor, TaskCtorT, TaskParams
from .task import TaskCtor, TaskParams
from .task_def import TaskDef, TaskSpec
from .tasklib.builtin_pkg import TaskNull


class PackageDef(BaseModel):
Expand All @@ -56,108 +57,146 @@ def getTask(self, name : str) -> 'TaskDef':
def mkPackage(self, session, params : Dict[str,Any] = None) -> 'Package':
ret = Package(self.name)

for task in self.tasks:
if task.pyclass is not None:
# This task provides a Python implementation
# Our task is to create a TaskCtor that includes a
# parameters class and task class
if task.uses is not None:
# Find the base for this task
ctor_t = session.getTaskCtor(task_t, self)
else:
ctor_t = None

# Construct a composite set of parameters
# - Merge parameters from the base (if any) with local
field_m = {}
print("task.params: %s" % str(task.params))
ptype_m = {
"str" : str,
"int" : int,
"float" : float,
"bool" : bool
}
for p in task.params.keys():
param = task.params[p]
if "type" in param.keys():
ptype_s = param["type"]
if ptype_s not in ptype_m.keys():
raise Exception("Unknown type %s" % ptype_s)
ptype = ptype_m[ptype_s]

if p in field_m.keys():
raise Exception("Duplicate field %s" % p)
if "value" in param.keys():
field_m[p] = (ptype, param["value"])
else:
field_m[p] = (ptype, )
else:
if p not in field_m.keys():
raise Exception("Field %s not found" % p)
if "value" not in param.keys():
raise Exception("No value specified for param %p" % p)
field_m[p] = (field_m[p][0], params["value"])
task_p = pydantic.create_model("Task%sParams" % task.name, **field_m)

# Now, lookup the class
last_dot = task.pyclass.rfind('.')
clsname = task.pyclass[last_dot+1:]
modname = task.pyclass[:last_dot]

try:
if modname not in sys.modules:
if self.basedir not in sys.path:
sys.path.append(self.basedir)
mod = importlib.import_module(modname)
else:
mod = sys.modules[modname]
except ModuleNotFoundError as e:
raise Exception("Failed to import module %s" % modname)

if not hasattr(mod, clsname):
raise Exception("Class %s not found in module %s" % (clsname, modname))
cls = getattr(mod, clsname)

ctor_t = TaskCtorT(task_p, cls)
elif task.uses is None:
# We use the built-in Null task
pass
else:
# Find package (not package_def) that implements this task
# Insert an indirect reference to that tasks's constructor
session.push_package(ret)

# Only call getTaskCtor if the task is in a different package
task_t = task.uses if isinstance(task.uses, TaskSpec) else TaskSpec(task.uses)
ctor_t = session.getTaskCtor(task_t, self)
tasks_m : Dict[str,TaskCtor]= {}

ctor_t = TaskParamCtor(
base=ctor_t,
params=task.params,
basedir=self.basedir,
depend_refs=task.depends)
ret.tasks[task.name] = ctor_t
for task in self.tasks:
if task.name in tasks_m.keys():
raise Exception("Duplicate task %s" % task.name)
tasks_m[task.name] = (task, self.basedir, ) # We'll add a TaskCtor later

for frag in self.fragment_l:
for task in frag.tasks:
if task.uses is not None:
# Find package (not package_def) that implements this task
# Insert an indirect reference to that tasks's constructor

# Only call getTaskCtor if the task is in a different package
task_t = task.uses if isinstance(task.uses, TaskSpec) else TaskSpec(task.uses)
ctor_t = session.getTaskCtor(task_t, self)

ctor_t = TaskParamCtor(
base=ctor_t,
params=task.params,
basedir=frag.basedir,
depend_refs=task.depends)
else:
# We use the Null task from the std package
raise Exception("")
if task.name in ret.tasks:
raise Exception("Task %s already defined" % task.name)
ret.tasks[task.name] = ctor_t
if task.name in tasks_m.keys():
raise Exception("Duplicate task %s" % task.name)
tasks_m[task.name] = (task, frag.basedir, ) # We'll add a TaskCtor later

# Now we have a unified map of the tasks declared in this package
for name in list(tasks_m.keys()):
task_i = tasks_m[name]
if len(task_i) < 3:
# Need to create the task ctor
ctor_t = self.mkTaskCtor(session, task_i[0], task_i[1], tasks_m)
tasks_m[name] = (task_i[0], task_i[1], ctor_t)
ret.tasks[name] = tasks_m[name][2]

session.pop_package(ret)

return ret

def mkTaskCtor(self, session, task, basedir, tasks_m) -> TaskCtor:
ctor_t : TaskCtor = None

if task.uses is not None:
# Find package (not package_def) that implements this task
# Insert an indirect reference to that tasks's constructor
last_dot = task.uses.rfind('.')

if last_dot != -1:
pkg_name = task.uses[:last_dot]
task_name = task.uses[last_dot+1:]
else:
pkg_name = None
task_name = task.uses

if pkg_name is not None:
pkg = session.getPackage(PackageSpec(pkg_name))
if pkg is None:
raise Exception("Failed to find package %s" % pkg_name)
ctor_t = pkg.getTaskCtor(task_name)
else:
if task_name not in tasks_m.keys():
raise Exception("Failed to find task %s" % task_name)
if len(tasks_m[task_name]) == 3:
ctor_t = tasks_m[task_name][2]
else:
task_i = tasks_m[task_name]
ctor_t = self.mkTaskCtor(session, task_i[0], task_i[1], tasks_m)
tasks_m[task_name] = ctor_t

if ctor_t is None:
# Provide a default implementation
ctor_t = TaskCtor(
task_ctor=TaskNull,
param_ctor=TaskParams)

if task.pyclass is not None:
# Built-in impl
# Now, lookup the class
last_dot = task.pyclass.rfind('.')
clsname = task.pyclass[last_dot+1:]
modname = task.pyclass[:last_dot]

try:
if modname not in sys.modules:
if self.basedir not in sys.path:
sys.path.append(self.basedir)
mod = importlib.import_module(modname)
else:
mod = sys.modules[modname]
except ModuleNotFoundError as e:
raise Exception("Failed to import module %s" % modname)

if not hasattr(mod, clsname):
raise Exception("Class %s not found in module %s" % (clsname, modname))
ctor_t.task_ctor = getattr(mod, clsname)

if task.uses is None:
ctor_t.param_ctor = TaskParams

decl_params = False
for value in task.params.values():
if "type" in value:
decl_params = True
break

if decl_params:
# We need to combine base parameters with new parameters
field_m = {}
# First, add parameters from the base class
for fname,info in ctor_t.param_ctor.model_fields.items():
print("Field: %s (%s)" % (fname, info.default))
field_m[fname] = (info.annotation, info.default)
ptype_m = {
"str" : str,
"int" : int,
"float" : float,
"bool" : bool
}
for p in task.params.keys():
param = task.params[p]
if type(param) == dict and "type" in param.keys():
ptype_s = param["type"]
if ptype_s not in ptype_m.keys():
raise Exception("Unknown type %s" % ptype_s)
ptype = ptype_m[ptype_s]

if p in field_m.keys():
raise Exception("Duplicate field %s" % p)
if "value" in param.keys():
field_m[p] = (ptype, param["value"])
else:
field_m[p] = (ptype, )
else:
if p not in field_m.keys():
raise Exception("Field %s not found" % p)
if type(param) != dict:
value = param
elif "value" in param.keys():
value = param["value"]
else:
raise Exception("No value specified for param %s: %s" % (
p, str(param)))
field_m[p] = (field_m[p][0], value)
print("field_m: %s" % str(field_m))
ctor_t.param_ctor = pydantic.create_model(
"Task%sParams" % task.name, **field_m)
else:
if len(task.params) > 0:
ctor_t.params = task.params
if len(task.depends) > 0:
ctor_t.depends.extends(task.depends)

return ctor_t

Loading

0 comments on commit 23a6013

Please sign in to comment.