Skip to content

Commit

Permalink
Add guide for AI Platform (previously Machine Learning Engine) Operat…
Browse files Browse the repository at this point in the history
…ors (apache#9798)
  • Loading branch information
vuppalli authored and scrambldchannel committed Jul 17, 2020
1 parent 6a842c0 commit 7190bf3
Show file tree
Hide file tree
Showing 5 changed files with 301 additions and 3 deletions.
30 changes: 30 additions & 0 deletions airflow/providers/google/cloud/example_dags/example_mlengine.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@
schedule_interval=None, # Override to match your needs
tags=['example'],
) as dag:
# [START howto_operator_gcp_mlengine_training]
training = MLEngineStartTrainingJobOperator(
task_id="training",
project_id=PROJECT_ID,
Expand All @@ -74,26 +75,34 @@
training_python_module=TRAINER_PY_MODULE,
training_args=[],
)
# [END howto_operator_gcp_mlengine_training]

# [START howto_operator_gcp_mlengine_create_model]
create_model = MLEngineCreateModelOperator(
task_id="create-model",
project_id=PROJECT_ID,
model={
"name": MODEL_NAME,
},
)
# [END howto_operator_gcp_mlengine_create_model]

# [START howto_operator_gcp_mlengine_get_model]
get_model = MLEngineGetModelOperator(
task_id="get-model",
project_id=PROJECT_ID,
model_name=MODEL_NAME,
)
# [END howto_operator_gcp_mlengine_get_model]

# [START howto_operator_gcp_mlengine_print_model]
get_model_result = BashOperator(
bash_command="echo \"{{ task_instance.xcom_pull('get-model') }}\"",
task_id="get-model-result",
)
# [END howto_operator_gcp_mlengine_print_model]

# [START howto_operator_gcp_mlengine_create_version1]
create_version = MLEngineCreateVersionOperator(
task_id="create-version",
project_id=PROJECT_ID,
Expand All @@ -108,7 +117,9 @@
"pythonVersion": "3.7"
}
)
# [END howto_operator_gcp_mlengine_create_version1]

# [START howto_operator_gcp_mlengine_create_version2]
create_version_2 = MLEngineCreateVersionOperator(
task_id="create-version-2",
project_id=PROJECT_ID,
Expand All @@ -123,25 +134,33 @@
"pythonVersion": "3.7"
}
)
# [END howto_operator_gcp_mlengine_create_version2]

# [START howto_operator_gcp_mlengine_default_version]
set_defaults_version = MLEngineSetDefaultVersionOperator(
task_id="set-default-version",
project_id=PROJECT_ID,
model_name=MODEL_NAME,
version_name="v2",
)
# [END howto_operator_gcp_mlengine_default_version]

# [START howto_operator_gcp_mlengine_list_versions]
list_version = MLEngineListVersionsOperator(
task_id="list-version",
project_id=PROJECT_ID,
model_name=MODEL_NAME,
)
# [END howto_operator_gcp_mlengine_list_versions]

# [START howto_operator_gcp_mlengine_print_versions]
list_version_result = BashOperator(
bash_command="echo \"{{ task_instance.xcom_pull('list-version') }}\"",
task_id="list-version-result",
)
# [END howto_operator_gcp_mlengine_print_versions]

# [START howto_operator_gcp_mlengine_get_prediction]
prediction = MLEngineStartBatchPredictionJobOperator(
task_id="prediction",
project_id=PROJECT_ID,
Expand All @@ -152,20 +171,25 @@
input_paths=[PREDICTION_INPUT],
output_path=PREDICTION_OUTPUT,
)
# [END howto_operator_gcp_mlengine_get_prediction]

# [START howto_operator_gcp_mlengine_delete_version]
delete_version = MLEngineDeleteVersionOperator(
task_id="delete-version",
project_id=PROJECT_ID,
model_name=MODEL_NAME,
version_name="v1"
)
# [END howto_operator_gcp_mlengine_delete_version]

# [START howto_operator_gcp_mlengine_delete_model]
delete_model = MLEngineDeleteModelOperator(
task_id="delete-model",
project_id=PROJECT_ID,
model_name=MODEL_NAME,
delete_contents=True
)
# [END howto_operator_gcp_mlengine_delete_model]

training >> create_version
training >> create_version_2
Expand All @@ -178,6 +202,7 @@
list_version >> delete_version
delete_version >> delete_model

# [START howto_operator_gcp_mlengine_get_metric]
def get_metric_fn_and_keys():
"""
Gets metric function and keys used to generate summary
Expand All @@ -186,7 +211,9 @@ def normalize_value(inst: Dict):
val = float(inst['dense_4'][0])
return tuple([val]) # returns a tuple.
return normalize_value, ['val'] # key order must match.
# [END howto_operator_gcp_mlengine_get_metric]

# [START howto_operator_gcp_mlengine_validate_error]
def validate_err_and_count(summary: Dict) -> Dict:
"""
Validate summary result
Expand All @@ -198,7 +225,9 @@ def validate_err_and_count(summary: Dict) -> Dict:
if summary['count'] != 20:
raise ValueError('Invalid value val != 20; summary={}'.format(summary))
return summary
# [END howto_operator_gcp_mlengine_validate_error]

# [START howto_operator_gcp_mlengine_evaluate]
evaluate_prediction, evaluate_summary, evaluate_validation = mlengine_operator_utils.create_evaluate_ops(
task_prefix="evaluate-ops",
data_format="TEXT",
Expand All @@ -218,6 +247,7 @@ def validate_err_and_count(summary: Dict) -> Dict:
version_name="v1",
py_interpreter="python3",
)
# [END howto_operator_gcp_mlengine_evaluate]

create_model >> create_version >> evaluate_prediction
evaluate_validation >> delete_version
36 changes: 36 additions & 0 deletions airflow/providers/google/cloud/operators/mlengine.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,10 @@ class MLEngineStartBatchPredictionJobOperator(BaseOperator):
"""
Start a Google Cloud ML Engine prediction job.
.. seealso::
For more information on how to use this operator, take a look at the guide:
:ref:`howto/operator:MLEngineStartBatchPredictionJobOperator`
NOTE: For model origin, users should consider exactly one from the
three options below:
Expand Down Expand Up @@ -351,6 +355,10 @@ class MLEngineCreateModelOperator(BaseOperator):
"""
Creates a new model.
.. seealso::
For more information on how to use this operator, take a look at the guide:
:ref:`howto/operator:MLEngineCreateModelOperator`
The model should be provided by the `model` parameter.
:param model: A dictionary containing the information about the model.
Expand Down Expand Up @@ -395,6 +403,10 @@ class MLEngineGetModelOperator(BaseOperator):
"""
Gets a particular model
.. seealso::
For more information on how to use this operator, take a look at the guide:
:ref:`howto/operator:MLEngineGetModelOperator`
The name of model shold be specified in `model_name`.
:param model_name: The name of the model.
Expand Down Expand Up @@ -438,6 +450,10 @@ class MLEngineDeleteModelOperator(BaseOperator):
"""
Deletes a model.
.. seealso::
For more information on how to use this operator, take a look at the guide:
:ref:`howto/operator:MLEngineDeleteModelOperator`
The model should be provided by the `model_name` parameter.
:param model_name: The name of the model.
Expand Down Expand Up @@ -615,6 +631,10 @@ class MLEngineCreateVersionOperator(BaseOperator):
"""
Creates a new version in the model
.. seealso::
For more information on how to use this operator, take a look at the guide:
:ref:`howto/operator:MLEngineCreateVersionOperator`
Model should be specified by `model_name`, in which case the `version` parameter should contain all the
information to create that version
Expand Down Expand Up @@ -678,6 +698,10 @@ class MLEngineSetDefaultVersionOperator(BaseOperator):
"""
Sets a version in the model.
.. seealso::
For more information on how to use this operator, take a look at the guide:
:ref:`howto/operator:MLEngineSetDefaultVersionOperator`
The model should be specified by `model_name` to be the default. The name of the version should be
specified in the `version_name` parameter.
Expand Down Expand Up @@ -741,6 +765,10 @@ class MLEngineListVersionsOperator(BaseOperator):
"""
Lists all available versions of the model
.. seealso::
For more information on how to use this operator, take a look at the guide:
:ref:`howto/operator:MLEngineListVersionsOperator`
The model should be specified by `model_name`.
:param model_name: The name of the Google Cloud ML Engine model that the version
Expand Down Expand Up @@ -794,6 +822,10 @@ class MLEngineDeleteVersionOperator(BaseOperator):
"""
Deletes the version from the model.
.. seealso::
For more information on how to use this operator, take a look at the guide:
:ref:`howto/operator:MLEngineDeleteVersionOperator`
The name of the version should be specified in `version_name` parameter from the model specified
by `model_name`.
Expand Down Expand Up @@ -874,6 +906,10 @@ class MLEngineStartTrainingJobOperator(BaseOperator):
"""
Operator for launching a MLEngine training job.
.. seealso::
For more information on how to use this operator, take a look at the guide:
:ref:`howto/operator:MLEngineStartTrainingJobOperator`
:param job_id: A unique templated id for the submitted Google MLEngine
training job. (templated)
:type job_id: str
Expand Down
1 change: 0 additions & 1 deletion docs/build
Original file line number Diff line number Diff line change
Expand Up @@ -358,7 +358,6 @@ MISSING_GOOGLLE_DOC_GUIDES = {
'datastore',
'dlp',
'gcs_to_bigquery',
'mlengine',
'mssql_to_gcs',
'mysql_to_gcs',
'postgres_to_gcs',
Expand Down
Loading

0 comments on commit 7190bf3

Please sign in to comment.