Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add ChoiceParameter to restrict parameter options to those specified. #1800

Merged
merged 11 commits into from
Aug 9, 2016
4 changes: 2 additions & 2 deletions luigi/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
DateIntervalParameter, TimeDeltaParameter,
IntParameter, FloatParameter, BooleanParameter, BoolParameter,
TaskParameter, EnumParameter, DictParameter, ListParameter, TupleParameter,
NumericalParameter
NumericalParameter, ChoiceParameter
)

from luigi import configuration
Expand All @@ -59,5 +59,5 @@
'FloatParameter', 'BooleanParameter', 'BoolParameter', 'TaskParameter',
'ListParameter', 'TupleParameter', 'EnumParameter', 'DictParameter',
'configuration', 'interface', 'file', 'run', 'build', 'event', 'Event',
'NumericalParameter'
'NumericalParameter', 'ChoiceParameter'
]
54 changes: 54 additions & 0 deletions luigi/parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -1034,3 +1034,57 @@ def parse(self, s):
raise ValueError(
"{s} is not in the set of {permitted_range}".format(
s=s, permitted_range=self._permitted_range))


class ChoiceParameter(Parameter):
"""
A parameter which takes two values:
1. an instance of :class:`~collections.Iterable` and
2. the class of the variables to convert to.

In the task definition, use

.. code-block:: python

class MyTask(luigi.Task):
my_param = luigi.ChoiceParameter(choices=[0.1, 0.2, 0.3], var_type=float)

At the command line, use

.. code-block:: console

$ luigi --module my_tasks MyTask --my-param 0.1

Consider using :class:`~luigi.EnumParameter` for a typed, structured
alternative. This class can perform the same role when all choices are the
same type and transparency of parameter value on the command line is
desired.
"""
def __init__(self, var_type=str, *args, **kwargs):
"""
:param function var_type: The type of the input variable, e.g. str, int,
float, etc.
Default: str
:param choices: An iterable, all of whose elements are of `var_type` to
restrict parameter choices to.
"""
if "choices" not in kwargs:
raise ParameterException("A choices iterable must be specified")
self._choices = set(kwargs.pop("choices"))
self._var_type = var_type
assert all(type(choice) is self._var_type for choice in self._choices), "Invalid type in choices"
super(ChoiceParameter, self).__init__(*args, **kwargs)
if self.description:
self.description += " "
else:
self.description = ""
self.description += (
"Choices: {" + ", ".join(str(choice) for choice in self._choices) + "}")

def parse(self, s):
var = self._var_type(s)
if var in self._choices:
return var
else:
raise ValueError("{s} is not a valid choice from {choices}".format(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this line is broken – can you fix

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, fixed. Don't know how I didn't copy that last line.

s=s, choices=self._choices))
55 changes: 55 additions & 0 deletions test/choice_parameter_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
# -*- coding: utf-8 -*-
#
# Copyright 2012-2015 Spotify AB
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from helpers import unittest

import luigi


class ChoiceParameterTest(unittest.TestCase):
def test_parse_str(self):
d = luigi.ChoiceParameter(choices=["1", "2", "3"])
self.assertEqual("3", d.parse("3"))

def test_parse_int(self):
d = luigi.ChoiceParameter(var_type=int, choices=[1, 2, 3])
self.assertEqual(3, d.parse(3))

def test_parse_int_conv(self):
d = luigi.ChoiceParameter(var_type=int, choices=[1, 2, 3])
self.assertEqual(3, d.parse("3"))

def test_invalid_choice(self):
d = luigi.ChoiceParameter(choices=["1", "2", "3"])
self.assertRaises(ValueError, lambda: d.parse("xyz"))

def test_invalid_choice_type(self):
self.assertRaises(AssertionError, lambda: luigi.ChoiceParameter(var_type=int, choices=[1, 2, "3"]))

def test_choices_parameter_exception(self):
self.assertRaises(luigi.parameter.ParameterException, lambda: luigi.ChoiceParameter(var_type=int))

def test_hash_str(self):
class Foo(luigi.Task):
args = luigi.ChoiceParameter(var_type=str, choices=["1", "2", "3"])
p = luigi.ChoiceParameter(var_type=str, choices=["3", "2", "1"])
self.assertEqual(hash(Foo(args="3").args), hash(p.parse("3")))

def test_serialize_parse(self):
a = luigi.ChoiceParameter(var_type=str, choices=["1", "2", "3"])
b = "3"
self.assertEqual(b, a.parse(a.serialize(b)))
5 changes: 5 additions & 0 deletions test/parameter_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -839,6 +839,11 @@ def testNumericalParameter(self):
p = luigi.NumericalParameter(min_value=-3, max_value=7, var_type=int, config_path=dict(section="foo", name="bar"))
self.assertEqual(-3, _value(p))

@with_config({"foo": {"bar": "3"}})
def testChoiceParameter(self):
p = luigi.ChoiceParameter(var_type=int, choices=[1, 2, 3], config_path=dict(section="foo", name="bar"))
self.assertEqual(3, _value(p))


class OverrideEnvStuff(LuigiTestCase):

Expand Down