diff --git a/luigi/__init__.py b/luigi/__init__.py index 5438b610c9..52857804aa 100644 --- a/luigi/__init__.py +++ b/luigi/__init__.py @@ -36,7 +36,7 @@ DateIntervalParameter, TimeDeltaParameter, IntParameter, FloatParameter, BooleanParameter, BoolParameter, TaskParameter, EnumParameter, DictParameter, ListParameter, TupleParameter, - NumericalParameter + NumericalParameter, ChoiceParameter ) from luigi import configuration @@ -59,5 +59,5 @@ 'FloatParameter', 'BooleanParameter', 'BoolParameter', 'TaskParameter', 'ListParameter', 'TupleParameter', 'EnumParameter', 'DictParameter', 'configuration', 'interface', 'file', 'run', 'build', 'event', 'Event', - 'NumericalParameter' + 'NumericalParameter', 'ChoiceParameter' ] diff --git a/luigi/parameter.py b/luigi/parameter.py index af86e10f3a..5487bbdbc7 100644 --- a/luigi/parameter.py +++ b/luigi/parameter.py @@ -1034,3 +1034,57 @@ def parse(self, s): raise ValueError( "{s} is not in the set of {permitted_range}".format( s=s, permitted_range=self._permitted_range)) + + +class ChoiceParameter(Parameter): + """ + A parameter which takes two values: + 1. an instance of :class:`~collections.Iterable` and + 2. the class of the variables to convert to. + + In the task definition, use + + .. code-block:: python + + class MyTask(luigi.Task): + my_param = luigi.ChoiceParameter(choices=[0.1, 0.2, 0.3], var_type=float) + + At the command line, use + + .. code-block:: console + + $ luigi --module my_tasks MyTask --my-param 0.1 + + Consider using :class:`~luigi.EnumParameter` for a typed, structured + alternative. This class can perform the same role when all choices are the + same type and transparency of parameter value on the command line is + desired. + """ + def __init__(self, var_type=str, *args, **kwargs): + """ + :param function var_type: The type of the input variable, e.g. str, int, + float, etc. + Default: str + :param choices: An iterable, all of whose elements are of `var_type` to + restrict parameter choices to. + """ + if "choices" not in kwargs: + raise ParameterException("A choices iterable must be specified") + self._choices = set(kwargs.pop("choices")) + self._var_type = var_type + assert all(type(choice) is self._var_type for choice in self._choices), "Invalid type in choices" + super(ChoiceParameter, self).__init__(*args, **kwargs) + if self.description: + self.description += " " + else: + self.description = "" + self.description += ( + "Choices: {" + ", ".join(str(choice) for choice in self._choices) + "}") + + def parse(self, s): + var = self._var_type(s) + if var in self._choices: + return var + else: + raise ValueError("{s} is not a valid choice from {choices}".format( + s=s, choices=self._choices)) diff --git a/test/choice_parameter_test.py b/test/choice_parameter_test.py new file mode 100644 index 0000000000..f47b227f69 --- /dev/null +++ b/test/choice_parameter_test.py @@ -0,0 +1,55 @@ +# -*- coding: utf-8 -*- +# +# Copyright 2012-2015 Spotify AB +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from helpers import unittest + +import luigi + + +class ChoiceParameterTest(unittest.TestCase): + def test_parse_str(self): + d = luigi.ChoiceParameter(choices=["1", "2", "3"]) + self.assertEqual("3", d.parse("3")) + + def test_parse_int(self): + d = luigi.ChoiceParameter(var_type=int, choices=[1, 2, 3]) + self.assertEqual(3, d.parse(3)) + + def test_parse_int_conv(self): + d = luigi.ChoiceParameter(var_type=int, choices=[1, 2, 3]) + self.assertEqual(3, d.parse("3")) + + def test_invalid_choice(self): + d = luigi.ChoiceParameter(choices=["1", "2", "3"]) + self.assertRaises(ValueError, lambda: d.parse("xyz")) + + def test_invalid_choice_type(self): + self.assertRaises(AssertionError, lambda: luigi.ChoiceParameter(var_type=int, choices=[1, 2, "3"])) + + def test_choices_parameter_exception(self): + self.assertRaises(luigi.parameter.ParameterException, lambda: luigi.ChoiceParameter(var_type=int)) + + def test_hash_str(self): + class Foo(luigi.Task): + args = luigi.ChoiceParameter(var_type=str, choices=["1", "2", "3"]) + p = luigi.ChoiceParameter(var_type=str, choices=["3", "2", "1"]) + self.assertEqual(hash(Foo(args="3").args), hash(p.parse("3"))) + + def test_serialize_parse(self): + a = luigi.ChoiceParameter(var_type=str, choices=["1", "2", "3"]) + b = "3" + self.assertEqual(b, a.parse(a.serialize(b))) diff --git a/test/parameter_test.py b/test/parameter_test.py index 9e15c6345e..b712b6d9d5 100644 --- a/test/parameter_test.py +++ b/test/parameter_test.py @@ -839,6 +839,11 @@ def testNumericalParameter(self): p = luigi.NumericalParameter(min_value=-3, max_value=7, var_type=int, config_path=dict(section="foo", name="bar")) self.assertEqual(-3, _value(p)) + @with_config({"foo": {"bar": "3"}}) + def testChoiceParameter(self): + p = luigi.ChoiceParameter(var_type=int, choices=[1, 2, 3], config_path=dict(section="foo", name="bar")) + self.assertEqual(3, _value(p)) + class OverrideEnvStuff(LuigiTestCase):