Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add ability to get a list of supported pipeline tasks #14732

Merged
merged 1 commit into from
Dec 13, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 2 additions & 4 deletions src/transformers/commands/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

from argparse import ArgumentParser

from ..pipelines import SUPPORTED_TASKS, TASK_ALIASES, Pipeline, PipelineDataFormat, pipeline
from ..pipelines import Pipeline, PipelineDataFormat, get_supported_tasks, pipeline
from ..utils import logging
from . import BaseTransformersCLICommand

Expand Down Expand Up @@ -63,9 +63,7 @@ def __init__(self, nlp: Pipeline, reader: PipelineDataFormat):
@staticmethod
def register_subcommand(parser: ArgumentParser):
run_parser = parser.add_parser("run", help="Run a pipeline through the CLI")
run_parser.add_argument(
"--task", choices=list(SUPPORTED_TASKS.keys()) + list(TASK_ALIASES.keys()), help="Task to run"
)
run_parser.add_argument("--task", choices=get_supported_tasks(), help="Task to run")
run_parser.add_argument("--input", type=str, help="Path to the file to use for inference")
run_parser.add_argument("--output", type=str, help="Path to the file that will be used post to write results.")
run_parser.add_argument("--model", type=str, help="Name or path to the model to instantiate.")
Expand Down
4 changes: 2 additions & 2 deletions src/transformers/commands/serving.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from argparse import ArgumentParser, Namespace
from typing import Any, List, Optional

from ..pipelines import SUPPORTED_TASKS, TASK_ALIASES, Pipeline, pipeline
from ..pipelines import Pipeline, get_supported_tasks, pipeline
from ..utils import logging
from . import BaseTransformersCLICommand

Expand Down Expand Up @@ -104,7 +104,7 @@ def register_subcommand(parser: ArgumentParser):
serve_parser.add_argument(
"--task",
type=str,
choices=list(SUPPORTED_TASKS.keys()) + list(TASK_ALIASES.keys()),
choices=get_supported_tasks(),
help="The task to run the pipeline on",
)
serve_parser.add_argument("--host", type=str, default="localhost", help="Interface the server will listen on.")
Expand Down
15 changes: 11 additions & 4 deletions src/transformers/pipelines/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union

from ..configuration_utils import PretrainedConfig
from ..feature_extraction_utils import PreTrainedFeatureExtractor
Expand Down Expand Up @@ -252,6 +252,15 @@
}


def get_supported_tasks() -> List[str]:
"""
Returns a list of supported task strings.
"""
supported_tasks = list(SUPPORTED_TASKS.keys()) + list(TASK_ALIASES.keys())
supported_tasks.sort()
return supported_tasks


def get_task(model: str, use_auth_token: Optional[str] = None) -> str:
tmp = io.BytesIO()
headers = {}
Expand Down Expand Up @@ -320,9 +329,7 @@ def check_task(task: str) -> Tuple[Dict, Any]:
return targeted_task, (tokens[1], tokens[3])
raise KeyError(f"Invalid translation task {task}, use 'translation_XX_to_YY' format")

raise KeyError(
f"Unknown task {task}, available tasks are {list(SUPPORTED_TASKS.keys()) + ['translation_XX_to_YY']}"
)
raise KeyError(f"Unknown task {task}, available tasks are {get_supported_tasks() + ['translation_XX_to_YY']}")


def pipeline(
Expand Down