From 45f34a2c64cb0a60e848e31cf0668f1a08c16c1f Mon Sep 17 00:00:00 2001 From: Eric Charles Date: Tue, 4 Jan 2022 15:09:03 -0800 Subject: [PATCH] Updated .gitignore --- .gitignore | 1 + ceci/pipeline.py | 83 +++++++++++++++++++++++++++++---------- ceci/stage.py | 2 +- tests/test_interactive.py | 2 +- 4 files changed, 66 insertions(+), 22 deletions(-) diff --git a/.gitignore b/.gitignore index 01f5e6b..47069f2 100644 --- a/.gitignore +++ b/.gitignore @@ -13,6 +13,7 @@ tests/cwl nb/Untitled.ipynb .ipynb_checkpoints/ .coverage +.coverage.* .eggs out/ test/ diff --git a/ceci/pipeline.py b/ceci/pipeline.py index 5586686..0e15026 100644 --- a/ceci/pipeline.py +++ b/ceci/pipeline.py @@ -552,6 +552,31 @@ def remove_stage(self, name): self.stage_names.remove(name) del self.stage_execution_config[name] + + def get_stage_aliases(self, stage_name, stages_config=None): + """Get the aliases for a particular stage + + Parameters + ---------- + stage_name : str + The name of the stage in question + stages_config : Mapping or None + Configuration dictionary, used if stages have not been created yet + + Returns + ------- + aliases : Mapping + The aliases + """ + stage = None + sec = self.stage_execution_config.get(stage_name) + if sec is not None: + stage = sec.stage_obj + if stage is not None: + return stage.config.get('aliases', {}) + return stages_config.get(stage_name, {}).get('aliases', {}) + + def ordered_stages(self, overall_inputs, stages_config=None): """Produce a linear ordering for the stages. 
@@ -587,6 +612,7 @@ def ordered_stages(self, overall_inputs, stages_config=None): else: stage_config_data = {} + # First pass, get the classes for all the stages stage_classes = [] for stage_name in stage_names: sec = self.stage_execution_config[stage_name] @@ -595,14 +621,19 @@ def ordered_stages(self, overall_inputs, stages_config=None): n = len(stage_names) # Check for a pipeline output that is already given as an input - for stage in stage_classes: - for tag in stage.output_tags(): - if tag in overall_inputs: + for stage_name in stage_names: + stage_class = self.stage_execution_config[stage_name].stage_class + stage_aliases = self.get_stage_aliases(stage_name, stage_config_data) + for tag in stage_class.output_tags(): + aliased_tag = stage_aliases.get(tag, tag) + if aliased_tag in overall_inputs: raise ValueError( - f"Pipeline stage {stage.instance_name} " - f"generates output {tag}, but " + f"Pipeline stage {stage_name} " + f"generates output {aliased_tag}, but " "it is already an overall input" ) + + # Now check that the stage names are unique stage_set = set(stage_names) if len(stage_set) < len(stage_classes): raise ValueError("Some stages are included twice in your pipeline") @@ -611,21 +642,28 @@ def ordered_stages(self, overall_inputs, stages_config=None): # as an input. 
This is the equivalent of the adjacency matrix # in graph-speak dependencies = collections.defaultdict(list) - for stage in stage_classes: - for tag in stage.input_tags(): - dependencies[tag].append(stage) + for stage_name in stage_names: + stage_class = self.stage_execution_config[stage_name].stage_class + stage_aliases = self.get_stage_aliases(stage_name, stage_config_data) + for tag in stage_class.input_tags(): + aliased_tag = stage_aliases.get(tag, tag) + dependencies[aliased_tag].append(stage_name) # count the number of inputs required by each stage - missing_input_counts = {stage: len(stage.inputs) for stage in stage_classes} + missing_input_counts = {} + for stage_name in stage_names: + stage_class = self.stage_execution_config[stage_name].stage_class + missing_input_counts[stage_name] = len(stage_class.inputs) + found_inputs = set() # record the stages which are receiving overall inputs for tag in overall_inputs: found_inputs.add(tag) - for stage in dependencies[tag]: - missing_input_counts[stage] -= 1 + for stage_name in dependencies[tag]: + missing_input_counts[stage_name] -= 1 # find all the stages that are ready because they have no missing inputs - queue = [stage for stage in stage_classes if missing_input_counts[stage] == 0] + queue = [stage_name for stage_name in stage_names if missing_input_counts[stage_name] == 0] ordered_stages = [] all_inputs = overall_inputs.copy() @@ -633,9 +671,10 @@ def ordered_stages(self, overall_inputs, stages_config=None): # make the ordering while queue: # get the next stage that has no inputs missing - stage_class = queue.pop() - sec = self.stage_execution_config[stage_class.name] - stage_config = stage_config_data.get(stage_class.name, {}) + stage_name = queue.pop() + sec = self.stage_execution_config[stage_name] + stage_class = sec.stage_class + stage_config = stage_config_data.get(stage_name, {}) stage_config.update(all_inputs) stage_config['config'] = stages_config if sec.stage_obj is None: @@ -647,9 +686,10 @@ def 
ordered_stages(self, overall_inputs, stages_config=None): stage_outputs = stage.find_outputs('.') for tag in stage.output_tags(): # find all the next_stages that depend on that file - found_inputs.add(tag) - all_inputs[tag] = stage_outputs[tag] - for next_stage in dependencies[tag]: + aliased_tag = stage.get_aliased_tag(tag) + found_inputs.add(aliased_tag) + all_inputs[aliased_tag] = stage_outputs[aliased_tag] + for next_stage in dependencies[aliased_tag]: # record that the next stage now has one less # missing dependency missing_input_counts[next_stage] -= 1 @@ -663,10 +703,13 @@ def ordered_stages(self, overall_inputs, stages_config=None): # Try to diagnose it here. if len(ordered_stages) != n: stages_missing_inputs = [ - stage for (stage, count) in missing_input_counts.items() if count > 0 + stage_name for (stage_name, count) in missing_input_counts.items() if count > 0 ] msg1 = [] - for stage in stages_missing_inputs: + for stage_name in stages_missing_inputs: + stage = self.stage_execution_config[stage_name].stage_obj + if stage is None: + raise ValueError("Object has not been created for {stage_name}") missing_inputs = [ tag for tag in stage.input_tags() if tag not in found_inputs ] diff --git a/ceci/stage.py b/ceci/stage.py index eb2c6cf..ea3dd8c 100644 --- a/ceci/stage.py +++ b/ceci/stage.py @@ -159,7 +159,7 @@ def load_configs(self, args): raise ValueError( f""" -Missing these names on the command line: +{self.instance_name} Missing these names on the command line: Input names: {missing_inputs}""" ) diff --git a/tests/test_interactive.py b/tests/test_interactive.py index 900c0ba..11ffbaa 100644 --- a/tests/test_interactive.py +++ b/tests/test_interactive.py @@ -49,7 +49,7 @@ def test_config(): assert config.get_type('chunk_rows') == int -def test_interactive(): +def test_interactive_pipeline(): # Load the pipeline interactively, this is just a temp fix to # get the run_config and stage_config