From e195a980bc8e9d42f3eb4ac134950977b9e5158f Mon Sep 17 00:00:00 2001 From: Cooper Gillan Date: Sun, 16 Aug 2020 04:30:24 -0500 Subject: [PATCH] Add type annotations for mlengine_operator_utils (#10297) Add type annotations, including a few changes to ensure the right types are passed through. Specifically, if region is not given, it must be provided in the DAG's default_args. --- .../cloud/utils/mlengine_operator_utils.py | 39 +++++++++++-------- 1 file changed, 23 insertions(+), 16 deletions(-) diff --git a/airflow/providers/google/cloud/utils/mlengine_operator_utils.py b/airflow/providers/google/cloud/utils/mlengine_operator_utils.py index 4e3d4d190365a..2231b2d14641d 100644 --- a/airflow/providers/google/cloud/utils/mlengine_operator_utils.py +++ b/airflow/providers/google/cloud/utils/mlengine_operator_utils.py @@ -24,31 +24,35 @@ import json import os import re +from typing import Callable, Dict, Iterable, List, Optional, Tuple, TypeVar from urllib.parse import urlsplit import dill +from airflow import DAG from airflow.exceptions import AirflowException from airflow.operators.python import PythonOperator from airflow.providers.google.cloud.hooks.gcs import GCSHook from airflow.providers.google.cloud.operators.dataflow import DataflowCreatePythonJobOperator from airflow.providers.google.cloud.operators.mlengine import MLEngineStartBatchPredictionJobOperator - -def create_evaluate_ops(task_prefix, # pylint: disable=too-many-arguments - data_format, - input_paths, - prediction_path, - metric_fn_and_keys, - validate_fn, - batch_prediction_job_id=None, - project_id=None, - region=None, - dataflow_options=None, - model_uri=None, - model_name=None, - version_name=None, - dag=None, +T = TypeVar("T", bound=Callable) # pylint: disable=invalid-name + + +def create_evaluate_ops(task_prefix: str, # pylint: disable=too-many-arguments + data_format: str, + input_paths: List[str], + prediction_path: str, + metric_fn_and_keys: Tuple[T, Iterable[str]], + validate_fn: T, + batch_prediction_job_id: Optional[str] = None, + region: Optional[str] = None, + project_id: Optional[str] = None, + dataflow_options: Optional[Dict] = None, + model_uri: Optional[str] = None, + model_name: Optional[str] = None, + version_name: Optional[str] = None, + dag: Optional[DAG] = None, py_interpreter="python3"): """ Creates Operators needed for model evaluation and returns. @@ -186,6 +190,9 @@ def validate_err_and_count(summary): :rtype: tuple(DataFlowPythonOperator, DataFlowPythonOperator, PythonOperator) """ + batch_prediction_job_id = batch_prediction_job_id or "" + dataflow_options = dataflow_options or {} + region = region or "" # Verify that task_prefix doesn't have any special characters except hyphen # '-', which is the only allowed non-alphanumeric character by Dataflow. @@ -203,7 +210,7 @@ def validate_err_and_count(summary): if dag is not None and dag.default_args is not None: default_args = dag.default_args project_id = project_id or default_args.get('project_id') - region = region or default_args.get('region') + region = region or default_args['region'] model_name = model_name or default_args.get('model_name') version_name = version_name or default_args.get('version_name') dataflow_options = dataflow_options or \