diff --git a/docs/get_started.ipynb b/docs/get_started.ipynb
index aa4962d4..e0ce5e79 100644
--- a/docs/get_started.ipynb
+++ b/docs/get_started.ipynb
@@ -6,7 +6,7 @@
"id": "KpXGE33umpig"
},
"source": [
- "\u003c!-- See: www.tensorflow.org/tfx/transform/ --\u003e\n",
+ "\n",
"\n",
"# Get Started with TensorFlow Transform"
]
@@ -68,8 +68,10 @@
},
"outputs": [],
"source": [
- "import pkg_resources\n",
"import importlib\n",
+ "\n",
+ "import pkg_resources\n",
+ "\n",
"importlib.reload(pkg_resources)"
]
},
@@ -87,10 +89,7 @@
"import tensorflow as tf\n",
"import tensorflow_transform as tft\n",
"import tensorflow_transform.beam as tft_beam\n",
- "\n",
- "from tensorflow_transform.tf_metadata import dataset_metadata\n",
- "from tensorflow_transform.tf_metadata import schema_utils\n",
- "\n",
+ "from tensorflow_transform.tf_metadata import dataset_metadata, schema_utils\n",
"from tfx_bsl.public import tfxio"
]
},
@@ -105,7 +104,7 @@
"The *preprocessing function* is the most important concept of `tf.Transform`.\n",
"The preprocessing function is a logical description of a transformation of the\n",
"dataset. The preprocessing function accepts and returns a dictionary of tensors,\n",
- "where a *tensor* means `Tensor` or `SparseTensor`. There are two kinds of\n",
+ "where a *tensor* means `Tensor` or `SparseTensor`. There are three kinds of\n",
"functions used to define the preprocessing function:\n",
"\n",
"1. Any function that accepts and returns tensors. These add TensorFlow\n",
@@ -117,7 +116,11 @@
" over the entire dataset to generate a constant tensor that is returned as the\n",
" output. For example, `tft.min` computes the minimum of a tensor over the\n",
" dataset. `tf.Transform` provides a fixed set of analyzers, but this will be\n",
- " extended in future versions.\n"
+ " extended in future versions.\n",
+ "3. Any stateless [preprocessing layers](https://www.tensorflow.org/guide/keras/preprocessing_layers) (i.e. these layers must not invoke the ```adapt()``` method). These can be added as operations to the graph as they do not require a full pass over the data outside of the management of ```tf.Transform```. For example, you can add
\n",
+ "[tf.keras.layers.experimental.preprocessing.HashedCrossing](https://www.tensorflow.org/api_docs/python/tf/keras/layers/experimental/preprocessing/HashedCrossing),
\n",
+ "but not
\n",
+ "[tf.keras.layers.Normalization](https://www.tensorflow.org/api_docs/python/tf/keras/layers/Normalization),
as the latter needs to be adapted over the entire dataset. Do note that if you use [Lambda layers](https://www.tensorflow.org/api_docs/python/tf/keras/layers/Lambda), there are some de-serialization limitations which might prevent ```preprocessing_fn``` from being fully re-loaded off of disk by [tft.TFTransformOutput](https://www.tensorflow.org/tfx/transform/api_docs/python/tft/TFTransformOutput). \n"
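+ "\n",
+ "Below is a minimal sketch (with hypothetical feature names `a` and `b`) of using a stateless layer directly inside `preprocessing_fn`:\n",
+ "\n",
+ "```python\n",
+ "def preprocessing_fn(inputs):\n",
+ "    # `HashedCrossing` is stateless, so it runs as ordinary graph ops.\n",
+ "    crossing = tf.keras.layers.experimental.preprocessing.HashedCrossing(num_bins=10)\n",
+ "    return {\"a_x_b\": crossing((inputs[\"a\"], inputs[\"b\"]))}\n",
+ "```\n"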
]
},
{
@@ -143,19 +146,19 @@
"outputs": [],
"source": [
"def preprocessing_fn(inputs):\n",
- " x = inputs['x']\n",
- " y = inputs['y']\n",
- " s = inputs['s']\n",
- " x_centered = x - tft.mean(x)\n",
- " y_normalized = tft.scale_to_0_1(y)\n",
- " s_integerized = tft.compute_and_apply_vocabulary(s)\n",
- " x_centered_times_y_normalized = x_centered * y_normalized\n",
- " return {\n",
- " 'x_centered': x_centered,\n",
- " 'y_normalized': y_normalized,\n",
- " 'x_centered_times_y_normalized': x_centered_times_y_normalized,\n",
- " 's_integerized': s_integerized\n",
- " }"
+ " x = inputs[\"x\"]\n",
+ " y = inputs[\"y\"]\n",
+ " s = inputs[\"s\"]\n",
+ " x_centered = x - tft.mean(x)\n",
+ " y_normalized = tft.scale_to_0_1(y)\n",
+ " s_integerized = tft.compute_and_apply_vocabulary(s)\n",
+ " x_centered_times_y_normalized = x_centered * y_normalized\n",
+ " return {\n",
+ " \"x_centered\": x_centered,\n",
+ " \"y_normalized\": y_normalized,\n",
+ " \"x_centered_times_y_normalized\": x_centered_times_y_normalized,\n",
+ " \"s_integerized\": s_integerized,\n",
+ " }"
]
},
{
@@ -244,22 +247,26 @@
"outputs": [],
"source": [
"raw_data = [\n",
- " {'x': 1, 'y': 1, 's': 'hello'},\n",
- " {'x': 2, 'y': 2, 's': 'world'},\n",
- " {'x': 3, 'y': 3, 's': 'hello'}\n",
+ " {\"x\": 1, \"y\": 1, \"s\": \"hello\"},\n",
+ " {\"x\": 2, \"y\": 2, \"s\": \"world\"},\n",
+ " {\"x\": 3, \"y\": 3, \"s\": \"hello\"},\n",
"]\n",
"\n",
"raw_data_metadata = dataset_metadata.DatasetMetadata(\n",
- " schema_utils.schema_from_feature_spec({\n",
- " 'y': tf.io.FixedLenFeature([], tf.float32),\n",
- " 'x': tf.io.FixedLenFeature([], tf.float32),\n",
- " 's': tf.io.FixedLenFeature([], tf.string),\n",
- " }))\n",
+ " schema_utils.schema_from_feature_spec(\n",
+ " {\n",
+ " \"y\": tf.io.FixedLenFeature([], tf.float32),\n",
+ " \"x\": tf.io.FixedLenFeature([], tf.float32),\n",
+ " \"s\": tf.io.FixedLenFeature([], tf.string),\n",
+ " }\n",
+ " )\n",
+ ")\n",
"\n",
"with tft_beam.Context(temp_dir=tempfile.mkdtemp()):\n",
- " transformed_dataset, transform_fn = (\n",
- " (raw_data, raw_data_metadata) |\n",
- " tft_beam.AnalyzeAndTransformDataset(preprocessing_fn))"
+ " transformed_dataset, transform_fn = (\n",
+ " raw_data,\n",
+ " raw_data_metadata,\n",
+ " ) | tft_beam.AnalyzeAndTransformDataset(preprocessing_fn)"
]
},
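+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "7f3e9b2a4c1d"
+ },
+ "outputs": [],
+ "source": [
+ "# `transformed_dataset` is itself a `(data, metadata)` pair, mirroring the\n",
+ "# `(raw_data, raw_data_metadata)` input tuple. A sketch of unpacking it:\n",
+ "transformed_data, transformed_metadata = transformed_dataset"
+ ]
+ },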
{
@@ -336,8 +343,20 @@
"outputs": [],
"source": [
"with tft_beam.Context(temp_dir=tempfile.mkdtemp()):\n",
- " transformed_data, transform_fn = (\n",
- " my_data | tft_beam.AnalyzeAndTransformDataset(preprocessing_fn))"
+ " transformed_data, transform_fn = my_data | tft_beam.AnalyzeAndTransformDataset(\n",
+ " preprocessing_fn\n",
+ " )"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "3ecf416c08f5"
+ },
+ "outputs": [],
+ "source": [
+ "transform_fn"
]
},
{
@@ -349,8 +368,8 @@
"outputs": [],
"source": [
"with tft_beam.Context(temp_dir=tempfile.mkdtemp()):\n",
- " transform_fn = my_data | tft_beam.AnalyzeDataset(preprocessing_fn)\n",
- " transformed_data = (my_data, transform_fn) | tft_beam.TransformDataset()"
+ " transform_fn = my_data | tft_beam.AnalyzeDataset(preprocessing_fn)\n",
+ " transformed_data = (my_data, transform_fn) | tft_beam.TransformDataset()"
]
},
{
@@ -386,7 +405,7 @@
"## Data Formats and Schema\n",
"\n",
"TFT Beam implementation accepts two different input data formats. The\n",
- "\"instance dict\" format (as seen in the example above and [simple.ipynb](https://www.tensorflow.org/tfx/tutorials/transform/simple) \u0026 [simple_example.py](https://github.com/tensorflow/transform/blob/master/examples/simple_example.py))\n",
+ "\"instance dict\" format (as seen in the example above and [simple.ipynb](https://www.tensorflow.org/tfx/tutorials/transform/simple) & [simple_example.py](https://github.com/tensorflow/transform/blob/master/examples/simple_example.py))\n",
"is an intuitive format and is suitable for small datasets while the TFXIO\n",
"([Apache Arrow](https://arrow.apache.org)) format provides improved performance\n",
"and is suitble for large datasets.\n",
@@ -424,10 +443,10 @@
"\n",
"Again, here is the definition of the schema for the example data:\n",
"\n",
- "\u003c!--\n",
+ ""
]
},
{
@@ -438,15 +457,17 @@
},
"outputs": [],
"source": [
- "from tensorflow_transform.tf_metadata import dataset_metadata\n",
- "from tensorflow_transform.tf_metadata import schema_utils\n",
+ "from tensorflow_transform.tf_metadata import dataset_metadata, schema_utils\n",
"\n",
"raw_data_metadata = dataset_metadata.DatasetMetadata(\n",
- " schema_utils.schema_from_feature_spec({\n",
- " 's': tf.io.FixedLenFeature([], tf.string),\n",
- " 'y': tf.io.FixedLenFeature([], tf.float32),\n",
- " 'x': tf.io.FixedLenFeature([], tf.float32),\n",
- " }))"
+ " schema_utils.schema_from_feature_spec(\n",
+ " {\n",
+ " \"s\": tf.io.FixedLenFeature([], tf.string),\n",
+ " \"y\": tf.io.FixedLenFeature([], tf.float32),\n",
+ " \"x\": tf.io.FixedLenFeature([], tf.float32),\n",
+ " }\n",
+ " )\n",
+ ")"
]
},
{
@@ -482,10 +503,10 @@
"For tabular data, our Apache Beam implementation\n",
"accepts Arrow `RecordBatch`es that consist of columns of the following types:\n",
"\n",
- " - `pa.list_(\u003cprimitive\u003e)`, where `\u003cprimitive\u003e` is `pa.int64()`, `pa.float32()`\n",
+ " - `pa.list_()`, where `` is `pa.int64()`, `pa.float32()`\n",
" `pa.binary()` or `pa.large_binary()`.\n",
"\n",
- " - `pa.large_list(\u003cprimitive\u003e)`\n",
+ " - `pa.large_list()`\n",
"\n",
"The toy input dataset we used above, when represented as a `RecordBatch`, looks\n",
"like the following:"
@@ -503,12 +524,13 @@
"\n",
"raw_data = [\n",
" pa.record_batch(\n",
- " data=[\n",
- " pa.array([[1], [2], [3]], pa.list_(pa.float32())),\n",
- " pa.array([[1], [2], [3]], pa.list_(pa.float32())),\n",
- " pa.array([['hello'], ['world'], ['hello']], pa.list_(pa.binary())),\n",
- " ],\n",
- " names=['x', 'y', 's'])\n",
+ " data=[\n",
+ " pa.array([[1], [2], [3]], pa.list_(pa.float32())),\n",
+ " pa.array([[1], [2], [3]], pa.list_(pa.float32())),\n",
+ " pa.array([[\"hello\"], [\"world\"], [\"hello\"]], pa.list_(pa.binary())),\n",
+ " ],\n",
+ " names=[\"x\", \"y\", \"s\"],\n",
+ " )\n",
"]"
]
},
@@ -540,9 +562,10 @@
"from tensorflow_metadata.proto.v0 import schema_pb2\n",
"\n",
"tensor_representation = {\n",
- " 'x': text_format.Parse(\n",
+ " \"x\": text_format.Parse(\n",
" \"\"\"dense_tensor { column_name: \"col1\" shape { dim { size: 2 } } }\"\"\",\n",
- " schema_pb2.TensorRepresentation())\n",
+ " schema_pb2.TensorRepresentation(),\n",
+ " )\n",
"}"
]
},
@@ -688,44 +711,55 @@
},
"outputs": [],
"source": [
- "#@title\n",
+ "# @title\n",
"ORDERED_CSV_COLUMNS = [\n",
- " 'age', 'workclass', 'fnlwgt', 'education', 'education-num',\n",
- " 'marital-status', 'occupation', 'relationship', 'race', 'sex',\n",
- " 'capital-gain', 'capital-loss', 'hours-per-week', 'native-country', 'label'\n",
+ " \"age\",\n",
+ " \"workclass\",\n",
+ " \"fnlwgt\",\n",
+ " \"education\",\n",
+ " \"education-num\",\n",
+ " \"marital-status\",\n",
+ " \"occupation\",\n",
+ " \"relationship\",\n",
+ " \"race\",\n",
+ " \"sex\",\n",
+ " \"capital-gain\",\n",
+ " \"capital-loss\",\n",
+ " \"hours-per-week\",\n",
+ " \"native-country\",\n",
+ " \"label\",\n",
"]\n",
"\n",
"CATEGORICAL_FEATURE_KEYS = [\n",
- " 'workclass',\n",
- " 'education',\n",
- " 'marital-status',\n",
- " 'occupation',\n",
- " 'relationship',\n",
- " 'race',\n",
- " 'sex',\n",
- " 'native-country',\n",
+ " \"workclass\",\n",
+ " \"education\",\n",
+ " \"marital-status\",\n",
+ " \"occupation\",\n",
+ " \"relationship\",\n",
+ " \"race\",\n",
+ " \"sex\",\n",
+ " \"native-country\",\n",
"]\n",
"\n",
"NUMERIC_FEATURE_KEYS = [\n",
- " 'age',\n",
- " 'capital-gain',\n",
- " 'capital-loss',\n",
- " 'hours-per-week',\n",
- " 'education-num',\n",
+ " \"age\",\n",
+ " \"capital-gain\",\n",
+ " \"capital-loss\",\n",
+ " \"hours-per-week\",\n",
+ " \"education-num\",\n",
"]\n",
"\n",
- "LABEL_KEY = 'label'\n",
+ "LABEL_KEY = \"label\"\n",
"\n",
"RAW_DATA_FEATURE_SPEC = dict(\n",
- " [(name, tf.io.FixedLenFeature([], tf.string))\n",
- " for name in CATEGORICAL_FEATURE_KEYS] +\n",
- " [(name, tf.io.FixedLenFeature([], tf.float32))\n",
- " for name in NUMERIC_FEATURE_KEYS] +\n",
- " [(LABEL_KEY, tf.io.FixedLenFeature([], tf.string))]\n",
+ " [(name, tf.io.FixedLenFeature([], tf.string)) for name in CATEGORICAL_FEATURE_KEYS]\n",
+ " + [(name, tf.io.FixedLenFeature([], tf.float32)) for name in NUMERIC_FEATURE_KEYS]\n",
+ " + [(LABEL_KEY, tf.io.FixedLenFeature([], tf.string))]\n",
")\n",
"\n",
"SCHEMA = tft.tf_metadata.dataset_metadata.DatasetMetadata(\n",
- " tft.tf_metadata.schema_utils.schema_from_feature_spec(RAW_DATA_FEATURE_SPEC)).schema"
+ " tft.tf_metadata.schema_utils.schema_from_feature_spec(RAW_DATA_FEATURE_SPEC)\n",
+ ").schema"
]
},
{
@@ -736,7 +770,7 @@
},
"outputs": [],
"source": [
- "pd.read_csv(train_data_file, names = ORDERED_CSV_COLUMNS).head()"
+ "pd.read_csv(train_data_file, names=ORDERED_CSV_COLUMNS).head()"
]
},
{
@@ -784,10 +818,9 @@
},
"outputs": [],
"source": [
- "from tfx_bsl.public import tfxio\n",
+ "import apache_beam as beam\n",
"from tfx_bsl.coders.example_coder import RecordBatchToExamples\n",
- "\n",
- "import apache_beam as beam"
+ "from tfx_bsl.public import tfxio"
]
},
{
@@ -801,15 +834,16 @@
"pipeline = beam.Pipeline()\n",
"\n",
"csv_tfxio = tfxio.BeamRecordCsvTFXIO(\n",
- " physical_format='text', column_names=ORDERED_CSV_COLUMNS, schema=SCHEMA)\n",
+ " physical_format=\"text\", column_names=ORDERED_CSV_COLUMNS, schema=SCHEMA\n",
+ ")\n",
"\n",
"raw_data = (\n",
" pipeline\n",
- " | 'ReadTrainData' \u003e\u003e beam.io.ReadFromText(\n",
- " train_data_file, coder=beam.coders.BytesCoder())\n",
- " | 'FixCommasTrainData' \u003e\u003e beam.Map(\n",
- " lambda line: line.replace(b', ', b','))\n",
- " | 'DecodeTrainData' \u003e\u003e csv_tfxio.BeamSource())"
+ " | \"ReadTrainData\"\n",
+ " >> beam.io.ReadFromText(train_data_file, coder=beam.coders.BytesCoder())\n",
+ " | \"FixCommasTrainData\" >> beam.Map(lambda line: line.replace(b\", \", b\",\"))\n",
+ " | \"DecodeTrainData\" >> csv_tfxio.BeamSource()\n",
+ ")"
]
},
{
@@ -842,13 +876,15 @@
},
"outputs": [],
"source": [
- "csv_tfxio = tfxio.CsvTFXIO(train_data_file,\n",
- " telemetry_descriptors=[], #???\n",
- " column_names=ORDERED_CSV_COLUMNS,\n",
- " schema=SCHEMA)\n",
+ "csv_tfxio = tfxio.CsvTFXIO(\n",
+ " train_data_file,\n",
+ " telemetry_descriptors=[], # ???\n",
+ " column_names=ORDERED_CSV_COLUMNS,\n",
+ " schema=SCHEMA,\n",
+ ")\n",
"\n",
"p2 = beam.Pipeline()\n",
- "raw_data_2 = p2 | 'TFXIORead' \u003e\u003e csv_tfxio.BeamSource()"
+ "raw_data_2 = p2 | \"TFXIORead\" >> csv_tfxio.BeamSource()"
]
},
{
@@ -871,39 +907,42 @@
"source": [
"NUM_OOV_BUCKETS = 1\n",
"\n",
- "def preprocessing_fn(inputs):\n",
- " \"\"\"Preprocess input columns into transformed columns.\"\"\"\n",
- " # Since we are modifying some features and leaving others unchanged, we\n",
- " # start by setting `outputs` to a copy of `inputs.\n",
- " outputs = inputs.copy()\n",
- "\n",
- " # Scale numeric columns to have range [0, 1].\n",
- " for key in NUMERIC_FEATURE_KEYS:\n",
- " outputs[key] = tft.scale_to_0_1(outputs[key])\n",
- "\n",
- " # For all categorical columns except the label column, we generate a\n",
- " # vocabulary but do not modify the feature. This vocabulary is instead\n",
- " # used in the trainer, by means of a feature column, to convert the feature\n",
- " # from a string to an integer id.\n",
- " for key in CATEGORICAL_FEATURE_KEYS:\n",
- " outputs[key] = tft.compute_and_apply_vocabulary(\n",
- " tf.strings.strip(inputs[key]),\n",
- " num_oov_buckets=NUM_OOV_BUCKETS,\n",
- " vocab_filename=key)\n",
"\n",
- " # For the label column we provide the mapping from string to index.\n",
- " with tf.init_scope():\n",
- " # `init_scope` - Only initialize the table once.\n",
- " initializer = tf.lookup.KeyValueTensorInitializer(\n",
- " keys=['\u003e50K', '\u003c=50K'],\n",
- " values=tf.cast(tf.range(2), tf.int64),\n",
- " key_dtype=tf.string,\n",
- " value_dtype=tf.int64)\n",
- " table = tf.lookup.StaticHashTable(initializer, default_value=-1)\n",
- "\n",
- " outputs[LABEL_KEY] = table.lookup(outputs[LABEL_KEY])\n",
- "\n",
- " return outputs"
+ "def preprocessing_fn(inputs):\n",
+ " \"\"\"Preprocess input columns into transformed columns.\"\"\"\n",
+ " # Since we are modifying some features and leaving others unchanged, we\n",
+ " # start by setting `outputs` to a copy of `inputs.\n",
+ " outputs = inputs.copy()\n",
+ "\n",
+ " # Scale numeric columns to have range [0, 1].\n",
+ " for key in NUMERIC_FEATURE_KEYS:\n",
+ " outputs[key] = tft.scale_to_0_1(outputs[key])\n",
+ "\n",
+ " # For all categorical columns except the label column, we generate a\n",
+ " # vocabulary but do not modify the feature. This vocabulary is instead\n",
+ " # used in the trainer, by means of a feature column, to convert the feature\n",
+ " # from a string to an integer id.\n",
+ " for key in CATEGORICAL_FEATURE_KEYS:\n",
+ " outputs[key] = tft.compute_and_apply_vocabulary(\n",
+ " tf.strings.strip(inputs[key]),\n",
+ " num_oov_buckets=NUM_OOV_BUCKETS,\n",
+ " vocab_filename=key,\n",
+ " )\n",
+ "\n",
+ " # For the label column we provide the mapping from string to index.\n",
+ " with tf.init_scope():\n",
+ " # `init_scope` - Only initialize the table once.\n",
+ " initializer = tf.lookup.KeyValueTensorInitializer(\n",
+ " keys=[\">50K\", \"<=50K\"],\n",
+ " values=tf.cast(tf.range(2), tf.int64),\n",
+ " key_dtype=tf.string,\n",
+ " value_dtype=tf.int64,\n",
+ " )\n",
+ " table = tf.lookup.StaticHashTable(initializer, default_value=-1)\n",
+ "\n",
+ " outputs[LABEL_KEY] = table.lookup(outputs[LABEL_KEY])\n",
+ "\n",
+ " return outputs"
]
},
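+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "9c2b7e4f1a3d"
+ },
+ "outputs": [],
+ "source": [
+ "# A quick sketch of the label mapping above, outside of `preprocessing_fn`,\n",
+ "# showing that unseen labels fall back to `default_value=-1`:\n",
+ "initializer = tf.lookup.KeyValueTensorInitializer(\n",
+ "    keys=[\">50K\", \"<=50K\"],\n",
+ "    values=tf.cast(tf.range(2), tf.int64),\n",
+ "    key_dtype=tf.string,\n",
+ "    value_dtype=tf.int64,\n",
+ ")\n",
+ "table = tf.lookup.StaticHashTable(initializer, default_value=-1)\n",
+ "table.lookup(tf.constant([\">50K\", \"<=50K\", \"other\"]))  # => [0, 1, -1]"
+ ]
+ },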
{
@@ -913,8 +952,8 @@
},
"source": [
"One difference from the previous example is the label column manually specifies\n",
- "the mapping from the string to an index. So `'\u003e50'` is mapped to `0` and\n",
- "`'\u003c=50K'` is mapped to `1` because it's useful to know which index in the\n",
+ "the mapping from the string to an index. So `'>50'` is mapped to `0` and\n",
+ "`'<=50K'` is mapped to `1` because it's useful to know which index in the\n",
"trained model corresponds to which label.\n",
"\n",
"The `record_batches` variable represents a `PCollection` of\n",
@@ -948,9 +987,12 @@
"source": [
"working_dir = tempfile.mkdtemp()\n",
"with tft_beam.Context(temp_dir=working_dir):\n",
- " transformed_dataset, transform_fn = (\n",
- " raw_dataset | tft_beam.AnalyzeAndTransformDataset(\n",
- " preprocessing_fn, output_record_batches=True))"
+ " (\n",
+ " transformed_dataset,\n",
+ " transform_fn,\n",
+ " ) = raw_dataset | tft_beam.AnalyzeAndTransformDataset(\n",
+ " preprocessing_fn, output_record_batches=True\n",
+ " )"
]
},
{
@@ -976,10 +1018,11 @@
"\n",
"_ = (\n",
" transformed_data\n",
- " | 'EncodeTrainData' \u003e\u003e\n",
- " beam.FlatMapTuple(lambda batch, _: RecordBatchToExamples(batch))\n",
- " | 'WriteTrainData' \u003e\u003e beam.io.WriteToTFRecord(\n",
- " os.path.join(output_dir , 'transformed.tfrecord')))"
+ " | \"EncodeTrainData\"\n",
+ " >> beam.FlatMapTuple(lambda batch, _: RecordBatchToExamples(batch))\n",
+ " | \"WriteTrainData\"\n",
+ " >> beam.io.WriteToTFRecord(os.path.join(output_dir, \"transformed.tfrecord\"))\n",
+ ")"
]
},
{
@@ -1000,9 +1043,7 @@
},
"outputs": [],
"source": [
- "_ = (\n",
- " transform_fn\n",
- " | 'WriteTransformFn' \u003e\u003e tft_beam.WriteTransformFn(output_dir))"
+ "_ = transform_fn | \"WriteTransformFn\" >> tft_beam.WriteTransformFn(output_dir)"
]
},
{
@@ -1025,6 +1066,17 @@
"result = pipeline.run().wait_until_finish()"
]
},
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "6457b1a7dea1"
+ },
+ "outputs": [],
+ "source": [
+ "print(pipeline)"
+ ]
+ },
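+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "2d8c5f0a9b37"
+ },
+ "outputs": [],
+ "source": [
+ "# A sketch of re-loading the transform written above (assuming the run\n",
+ "# completed and `output_dir` still holds the `WriteTransformFn` output):\n",
+ "tft_output = tft.TFTransformOutput(output_dir)\n",
+ "tft_output.transformed_feature_spec()"
+ ]
+ },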
{
"cell_type": "markdown",
"metadata": {