Skip to content

Commit

Permalink
Updated .gitignore
Browse files Browse the repository at this point in the history
  • Loading branch information
eacharles committed Jan 4, 2022
1 parent c15d4fa commit 45f34a2
Show file tree
Hide file tree
Showing 4 changed files with 66 additions and 22 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ tests/cwl
nb/Untitled.ipynb
.ipynb_checkpoints/
.coverage
.coverage.*
.eggs
out/
test/
Expand Down
83 changes: 63 additions & 20 deletions ceci/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -552,6 +552,31 @@ def remove_stage(self, name):
self.stage_names.remove(name)
del self.stage_execution_config[name]


def get_stage_aliases(self, stage_name, stages_config=None):
"""Get the aliases for a particular stage
Parameters
----------
stage_name : str
The name of the stage in question
stages_config : Mapping or None
Configurtion dictionary, used if stages have not been created yet
Returns
-------
aliases : Mapping
The aliases
"""
stage = None
sec = self.stage_execution_config.get(stage_name)
if sec is not None:
stage = sec.stage_obj
if stage is not None:
return stage.config.get('aliases', {})
return stages_config.get(stage_name, {}).get('aliases', {})


def ordered_stages(self, overall_inputs, stages_config=None):
"""Produce a linear ordering for the stages.
Expand Down Expand Up @@ -587,6 +612,7 @@ def ordered_stages(self, overall_inputs, stages_config=None):
else:
stage_config_data = {}

# First pass, get the classes for all the stages
stage_classes = []
for stage_name in stage_names:
sec = self.stage_execution_config[stage_name]
Expand All @@ -595,14 +621,19 @@ def ordered_stages(self, overall_inputs, stages_config=None):
n = len(stage_names)

# Check for a pipeline output that is already given as an input
for stage in stage_classes:
for tag in stage.output_tags():
if tag in overall_inputs:
for stage_name in stage_names:
stage_class = self.stage_execution_config[stage_name].stage_class
stage_aliases = self.get_stage_aliases(stage_name, stage_config_data)
for tag in stage_class.output_tags():
aliased_tag = stage_aliases.get(tag, tag)
if aliased_tag in overall_inputs:
raise ValueError(
f"Pipeline stage {stage.instance_name} "
f"generates output {tag}, but "
f"Pipeline stage {stage_name} "
f"generates output {aliased_tag}, but "
"it is already an overall input"
)

# Now check that the stage names are unique
stage_set = set(stage_names)
if len(stage_set) < len(stage_classes):
raise ValueError("Some stages are included twice in your pipeline")
Expand All @@ -611,31 +642,39 @@ def ordered_stages(self, overall_inputs, stages_config=None):
# as an input. This is the equivalent of the adjacency matrix
# in graph-speak
dependencies = collections.defaultdict(list)
for stage in stage_classes:
for tag in stage.input_tags():
dependencies[tag].append(stage)
for stage_name in stage_names:
stage_class = self.stage_execution_config[stage_name].stage_class
stage_aliases = self.get_stage_aliases(stage_name, stage_config_data)
for tag in stage_class.input_tags():
aliased_tag = stage_aliases.get(tag, tag)
dependencies[aliased_tag].append(stage_name)

# count the number of inputs required by each stage
missing_input_counts = {stage: len(stage.inputs) for stage in stage_classes}
missing_input_counts = {}
for stage_name in stage_names:
stage_class = self.stage_execution_config[stage_name].stage_class
missing_input_counts[stage_name] = len(stage_class.inputs)

found_inputs = set()
# record the stages which are receiving overall inputs
for tag in overall_inputs:
found_inputs.add(tag)
for stage in dependencies[tag]:
missing_input_counts[stage] -= 1
for stage_name in dependencies[tag]:
missing_input_counts[stage_name] -= 1

# find all the stages that are ready because they have no missing inputs
queue = [stage for stage in stage_classes if missing_input_counts[stage] == 0]
queue = [stage_name for stage_name in stage_names if missing_input_counts[stage_name] == 0]
ordered_stages = []

all_inputs = overall_inputs.copy()

# make the ordering
while queue:
# get the next stage that has no inputs missing
stage_class = queue.pop()
sec = self.stage_execution_config[stage_class.name]
stage_config = stage_config_data.get(stage_class.name, {})
stage_name = queue.pop()
sec = self.stage_execution_config[stage_name]
stage_class = sec.stage_class
stage_config = stage_config_data.get(stage_name, {})
stage_config.update(all_inputs)
stage_config['config'] = stages_config
if sec.stage_obj is None:
Expand All @@ -647,9 +686,10 @@ def ordered_stages(self, overall_inputs, stages_config=None):
stage_outputs = stage.find_outputs('.')
for tag in stage.output_tags():
# find all the next_stages that depend on that file
found_inputs.add(tag)
all_inputs[tag] = stage_outputs[tag]
for next_stage in dependencies[tag]:
aliased_tag = stage.get_aliased_tag(tag)
found_inputs.add(aliased_tag)
all_inputs[aliased_tag] = stage_outputs[aliased_tag]
for next_stage in dependencies[aliased_tag]:
# record that the next stage now has one less
# missing dependency
missing_input_counts[next_stage] -= 1
Expand All @@ -663,10 +703,13 @@ def ordered_stages(self, overall_inputs, stages_config=None):
# Try to diagnose it here.
if len(ordered_stages) != n:
stages_missing_inputs = [
stage for (stage, count) in missing_input_counts.items() if count > 0
stage_name for (stage_name, count) in missing_input_counts.items() if count > 0
]
msg1 = []
for stage in stages_missing_inputs:
for stage_name in stages_missing_inputs:
stage = self.stage_execution_config[stage_name].stage_obj
if stage is None:
raise ValueError("Object has not been created for {stage_name}")
missing_inputs = [
tag for tag in stage.input_tags() if tag not in found_inputs
]
Expand Down
2 changes: 1 addition & 1 deletion ceci/stage.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ def load_configs(self, args):
raise ValueError(
f"""
Missing these names on the command line:
{self.instance_name} Missing these names on the command line:
Input names: {missing_inputs}"""
)

Expand Down
2 changes: 1 addition & 1 deletion tests/test_interactive.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def test_config():
assert config.get_type('chunk_rows') == int


def test_interactive():
def test_interactive_pipeline():

# Load the pipeline interactively, this is just a temp fix to
# get the run_config and stage_config
Expand Down

0 comments on commit 45f34a2

Please sign in to comment.