From c17e7cde326640a135bc7236a0e41ae52471cb90 Mon Sep 17 00:00:00 2001 From: Suzen Fylke Date: Mon, 13 Dec 2021 08:31:50 -0500 Subject: [PATCH] Add ability to get a list of supported pipeline tasks (#14732) --- src/transformers/commands/run.py | 6 ++---- src/transformers/commands/serving.py | 4 ++-- src/transformers/pipelines/__init__.py | 15 +++++++++++---- 3 files changed, 15 insertions(+), 10 deletions(-) diff --git a/src/transformers/commands/run.py b/src/transformers/commands/run.py index 563a086a7d87..dbf067ae4d95 100644 --- a/src/transformers/commands/run.py +++ b/src/transformers/commands/run.py @@ -14,7 +14,7 @@ from argparse import ArgumentParser -from ..pipelines import SUPPORTED_TASKS, TASK_ALIASES, Pipeline, PipelineDataFormat, pipeline +from ..pipelines import Pipeline, PipelineDataFormat, get_supported_tasks, pipeline from ..utils import logging from . import BaseTransformersCLICommand @@ -63,9 +63,7 @@ def __init__(self, nlp: Pipeline, reader: PipelineDataFormat): @staticmethod def register_subcommand(parser: ArgumentParser): run_parser = parser.add_parser("run", help="Run a pipeline through the CLI") - run_parser.add_argument( - "--task", choices=list(SUPPORTED_TASKS.keys()) + list(TASK_ALIASES.keys()), help="Task to run" - ) + run_parser.add_argument("--task", choices=get_supported_tasks(), help="Task to run") run_parser.add_argument("--input", type=str, help="Path to the file to use for inference") run_parser.add_argument("--output", type=str, help="Path to the file that will be used post to write results.") run_parser.add_argument("--model", type=str, help="Name or path to the model to instantiate.") diff --git a/src/transformers/commands/serving.py b/src/transformers/commands/serving.py index dd2aec1f3aba..fbe77cb8d950 100644 --- a/src/transformers/commands/serving.py +++ b/src/transformers/commands/serving.py @@ -15,7 +15,7 @@ from argparse import ArgumentParser, Namespace from typing import Any, List, Optional -from ..pipelines import SUPPORTED_TASKS, TASK_ALIASES, Pipeline, pipeline +from ..pipelines import Pipeline, get_supported_tasks, pipeline from ..utils import logging from . import BaseTransformersCLICommand @@ -104,7 +104,7 @@ def register_subcommand(parser: ArgumentParser): serve_parser.add_argument( "--task", type=str, - choices=list(SUPPORTED_TASKS.keys()) + list(TASK_ALIASES.keys()), + choices=get_supported_tasks(), help="The task to run the pipeline on", ) serve_parser.add_argument("--host", type=str, default="localhost", help="Interface the server will listen on.") diff --git a/src/transformers/pipelines/__init__.py b/src/transformers/pipelines/__init__.py index 2f5ef55a78cd..2c4cc1688ead 100755 --- a/src/transformers/pipelines/__init__.py +++ b/src/transformers/pipelines/__init__.py @@ -20,7 +20,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import warnings -from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union from ..configuration_utils import PretrainedConfig from ..feature_extraction_utils import PreTrainedFeatureExtractor @@ -252,6 +252,15 @@ } +def get_supported_tasks() -> List[str]: + """ + Returns a list of supported task strings. + """ + supported_tasks = list(SUPPORTED_TASKS.keys()) + list(TASK_ALIASES.keys()) + supported_tasks.sort() + return supported_tasks + + def get_task(model: str, use_auth_token: Optional[str] = None) -> str: tmp = io.BytesIO() headers = {} @@ -320,9 +329,7 @@ def check_task(task: str) -> Tuple[Dict, Any]: return targeted_task, (tokens[1], tokens[3]) raise KeyError(f"Invalid translation task {task}, use 'translation_XX_to_YY' format") - raise KeyError( - f"Unknown task {task}, available tasks are {list(SUPPORTED_TASKS.keys()) + ['translation_XX_to_YY']}" - ) + raise KeyError(f"Unknown task {task}, available tasks are {get_supported_tasks() + ['translation_XX_to_YY']}") def pipeline(