diff --git a/ceci/config.py b/ceci/config.py index d805107..8828364 100644 --- a/ceci/config.py +++ b/ceci/config.py @@ -156,6 +156,15 @@ def __repr__(self): s += self.__str__() return s + def to_dict(self): + """ Forcibly return a dict where the values have been cast from StageParameter """ + return {key:cast_to_streamable(value) for key, value in dict.items(self)} + + def __iter__(self): + """ Override the __iter__ to work with `StageParameter` """ + d = self.to_dict() + return iter(d) + def __getitem__(self, key): """ Override the __getitem__ to work with `StageParameter` """ attr = dict.__getitem__(self, key) @@ -187,7 +196,7 @@ def items(self): def values(self): """ Override values() to get the parameters values instead of the objects """ return [cast_to_streamable(value) for value in dict.values(self)] - + def set_config(self, input_config, args): """ Utility function to load configuration diff --git a/tests/test_interactive.py b/tests/test_interactive.py index 751251f..cce88b0 100644 --- a/tests/test_interactive.py +++ b/tests/test_interactive.py @@ -23,7 +23,7 @@ def test_config(): config.free = 'dog' config.free = 42 - + try: config.chunk_rows = 'a' except TypeError: @@ -42,16 +42,28 @@ def test_config(): config['new_par'] = 'abc' assert config['new_par'] == 'abc' assert config.get_type('new_par') == str - + config.reset() assert config.chunk_rows == 5000 - + assert config.get_type('chunk_rows') == int values = config.values() for key, value in config.items(): #assert value == config[key].value assert value in values + + def check_func(cfg, **kwargs): + for k, v in kwargs.items(): + check_type = cfg.get_type(k) + if k is not None and v is not None: + assert check_type == type(v) + + check_func(config, **config) + + for k in iter(config): + assert k in config + @@ -64,14 +76,14 @@ def test_interactive_pipeline(): dry_pipe = Pipeline.read('tests/test.yml', dry_run=True) dry_pipe.run() - + pipe2 = Pipeline.interactive() overall_inputs = {'DM':'./tests/inputs/dm.txt', 'fiducial_cosmology':'./tests/inputs/fiducial_cosmology.txt'} inputs = overall_inputs.copy() inputs['metacalibration'] = True inputs['config'] = None - + pipe2.pipeline_files.update(**inputs) pipe2.build_stage(PZEstimationPipe) pipe2.build_stage(shearMeasurementPipe, apply_flag=False) @@ -93,11 +105,11 @@ def test_interactive_pipeline(): assert pipe2['WLGCCov'] == pipe2.WLGCCov rpr = repr(pipe2.WLGCCov.config) - + path = pipe2.pipeline_files.get_path('covariance_copy') assert pipe2.pipeline_files.get_tag(path) == 'covariance_copy' assert pipe2.pipeline_files.get_type('covariance_copy') == pipe2.WLGCCov.get_output_type('covariance') - + pipe2.run() @@ -111,19 +123,19 @@ def test_inter_pipe(): inputs['config'] = None pipe2.pipeline_files.update(**inputs) - + pipe2.build_stage(PZEstimationPipe, name='bob') assert isinstance(pipe2.bob, PZEstimationPipe) pipe2.remove_stage('bob') assert not hasattr(pipe2, 'bob') - - - - + + + + if __name__ == "__main__": test_config() test_interactive()