Skip to content

Commit

Permalink
Add ability to get a list of supported pipeline tasks (#14732)
Browse files Browse the repository at this point in the history
  • Loading branch information
codesue authored Dec 13, 2021
1 parent 3d66146 commit c17e7cd
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 10 deletions.
6 changes: 2 additions & 4 deletions src/transformers/commands/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

from argparse import ArgumentParser

from ..pipelines import SUPPORTED_TASKS, TASK_ALIASES, Pipeline, PipelineDataFormat, pipeline
from ..pipelines import Pipeline, PipelineDataFormat, get_supported_tasks, pipeline
from ..utils import logging
from . import BaseTransformersCLICommand

Expand Down Expand Up @@ -63,9 +63,7 @@ def __init__(self, nlp: Pipeline, reader: PipelineDataFormat):
@staticmethod
def register_subcommand(parser: ArgumentParser):
run_parser = parser.add_parser("run", help="Run a pipeline through the CLI")
run_parser.add_argument(
"--task", choices=list(SUPPORTED_TASKS.keys()) + list(TASK_ALIASES.keys()), help="Task to run"
)
run_parser.add_argument("--task", choices=get_supported_tasks(), help="Task to run")
run_parser.add_argument("--input", type=str, help="Path to the file to use for inference")
run_parser.add_argument("--output", type=str, help="Path to the file that will be used post to write results.")
run_parser.add_argument("--model", type=str, help="Name or path to the model to instantiate.")
Expand Down
4 changes: 2 additions & 2 deletions src/transformers/commands/serving.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from argparse import ArgumentParser, Namespace
from typing import Any, List, Optional

from ..pipelines import SUPPORTED_TASKS, TASK_ALIASES, Pipeline, pipeline
from ..pipelines import Pipeline, get_supported_tasks, pipeline
from ..utils import logging
from . import BaseTransformersCLICommand

Expand Down Expand Up @@ -104,7 +104,7 @@ def register_subcommand(parser: ArgumentParser):
serve_parser.add_argument(
"--task",
type=str,
choices=list(SUPPORTED_TASKS.keys()) + list(TASK_ALIASES.keys()),
choices=get_supported_tasks(),
help="The task to run the pipeline on",
)
serve_parser.add_argument("--host", type=str, default="localhost", help="Interface the server will listen on.")
Expand Down
15 changes: 11 additions & 4 deletions src/transformers/pipelines/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union

from ..configuration_utils import PretrainedConfig
from ..feature_extraction_utils import PreTrainedFeatureExtractor
Expand Down Expand Up @@ -252,6 +252,15 @@
}


def get_supported_tasks() -> List[str]:
"""
Returns a list of supported task strings.
"""
supported_tasks = list(SUPPORTED_TASKS.keys()) + list(TASK_ALIASES.keys())
supported_tasks.sort()
return supported_tasks


def get_task(model: str, use_auth_token: Optional[str] = None) -> str:
tmp = io.BytesIO()
headers = {}
Expand Down Expand Up @@ -320,9 +329,7 @@ def check_task(task: str) -> Tuple[Dict, Any]:
return targeted_task, (tokens[1], tokens[3])
raise KeyError(f"Invalid translation task {task}, use 'translation_XX_to_YY' format")

raise KeyError(
f"Unknown task {task}, available tasks are {list(SUPPORTED_TASKS.keys()) + ['translation_XX_to_YY']}"
)
raise KeyError(f"Unknown task {task}, available tasks are {get_supported_tasks() + ['translation_XX_to_YY']}")


def pipeline(
Expand Down

0 comments on commit c17e7cd

Please sign in to comment.