Fix to only transform raw data when requested. #268

Open · pritamdodeja wants to merge 1 commit into master

Conversation

pritamdodeja

Scenario

When pre-processed data is already available and the user wants to re-use it by passing read_raw_data_for_training=False to main, the flow was still calling common.transform_data on the raw data. This caused WriteTransformFn to fail, because artifacts already exist at the output location, and it also unnecessarily recomputed statistics.

Fix

This fix moves the common.transform_data invocation into the path where the raw data is processed for the first time.

When the main function was invoked with read_raw_data_for_training=False,
common.transform_data was still being called on the raw train and test
data. This fix moves the transformation into the block where
read_raw_data_for_training is True. The scenario here is that the data
has already been preprocessed, and the user wishes to re-use that
preprocessed data.
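
As a minimal sketch, the control flow after this change looks roughly like the following (the file paths and the transform_data call follow the census example, but the exact body of main in the repo differs):

def main(input_data_dir,
         working_dir,
         read_raw_data_for_training=True,
         num_train_instances=common.NUM_TRAIN_INSTANCES,
         num_test_instances=common.NUM_TEST_INSTANCES):
  # Illustrative paths; the census example reads the UCI Adult dataset.
  train_data_file = os.path.join(input_data_dir, 'adult.data')
  test_data_file = os.path.join(input_data_dir, 'adult.test')
  if read_raw_data_for_training:
    # First pass over the raw data: analyze and transform it, writing the
    # transform_fn and materialized datasets under working_dir.
    common.transform_data(train_data_file, test_data_file, working_dir)
  # When read_raw_data_for_training is False, the artifacts already present
  # in working_dir are re-used, so WriteTransformFn is never asked to write
  # over an existing output location.
  ...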
@zoyahav zoyahav self-requested a review May 6, 2022 10:48
zoyahav (Member) commented May 6, 2022

Thanks for the PR!
In this scenario, where was transform_data called the first time? Or is this describing a case where the main function is called again?
If the latter: this example is meant to be an end-to-end showcase of using TFT for preprocessing and using its outputs to train a model. We shouldn't skip TFT analysis; if main is called again, it's expected that the user either gives a fresh output path or clears the previous outputs.

It's true that transforming the raw data, serializing, and writing the results are unnecessary when read_raw_data_for_training=True. How about skipping that part and just using tft_beam.AnalyzeDataset for this case?

# The TFXIO output format is chosen for improved performance.
transformed_dataset, transform_fn = (
    raw_dataset | tft_beam.AnalyzeAndTransformDataset(
        preprocessing_fn, output_record_batches=True))
# Transformed metadata is not necessary for encoding.
transformed_data, _ = transformed_dataset
# Extract transformed RecordBatches, encode and write them to the given
# directory.
coder = RecordBatchToExamplesEncoder()
_ = (
    transformed_data
    | 'EncodeTrainData' >>
    beam.FlatMapTuple(lambda batch, _: coder.encode(batch))
    | 'WriteTrainData' >> beam.io.WriteToTFRecord(
        os.path.join(working_dir, TRANSFORMED_TRAIN_DATA_FILEBASE)))
# Now apply transform function to test data. In this case we remove the
# trailing period at the end of each line, and also ignore the header line
# that is present in the test data file.
raw_test_data = (
    pipeline
    | 'ReadTestData' >> beam.io.ReadFromText(
        test_data_file, skip_header_lines=1,
        coder=beam.coders.BytesCoder())
    | 'FixCommasTestData' >> beam.Map(
        lambda line: line.replace(b', ', b','))
    | 'RemoveTrailingPeriodsTestData' >> beam.Map(lambda line: line[:-1])
    | 'DecodeTestData' >> csv_tfxio.BeamSource())
raw_test_dataset = (raw_test_data, csv_tfxio.TensorAdapterConfig())
# The TFXIO output format is chosen for improved performance.
transformed_test_dataset = (
    (raw_test_dataset, transform_fn)
    | tft_beam.TransformDataset(output_record_batches=True))
# Transformed metadata is not necessary for encoding.
transformed_test_data, _ = transformed_test_dataset
# Extract transformed RecordBatches, encode and write them to the given
# directory.
_ = (
    transformed_test_data
    | 'EncodeTestData' >>
    beam.FlatMapTuple(lambda batch, _: coder.encode(batch))
    | 'WriteTestData' >> beam.io.WriteToTFRecord(
        os.path.join(working_dir, TRANSFORMED_TEST_DATA_FILEBASE)))

We could even always call tft_beam.AnalyzeDataset, and only call tft_beam.TransformDataset (followed by serializing and writing the data) when read_raw_data_for_training=False.
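
A hedged sketch of that variant, reusing names from the snippet above (raw_dataset, preprocessing_fn, working_dir, and the filebase constant all come from the surrounding example code; this is not the repo's actual implementation):

# Always analyze, so the transform_fn is produced either way.
transform_fn = (
    raw_dataset | 'Analyze' >> tft_beam.AnalyzeDataset(preprocessing_fn))
# Materialize transformed data only when it will be read back for training,
# i.e. when the training path does not read raw data.
if not read_raw_data_for_training:
  transformed_dataset = (
      (raw_dataset, transform_fn)
      | 'Transform' >> tft_beam.TransformDataset(output_record_batches=True))
  transformed_data, _ = transformed_dataset
  coder = RecordBatchToExamplesEncoder()
  _ = (
      transformed_data
      | 'EncodeTrainData' >> beam.FlatMapTuple(
          lambda batch, _: coder.encode(batch))
      | 'WriteTrainData' >> beam.io.WriteToTFRecord(
          os.path.join(working_dir, TRANSFORMED_TRAIN_DATA_FILEBASE)))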

pritamdodeja (Author) commented

The behavior of the default usage won't change, as main's signature is:

def main(input_data_dir,
         working_dir,
         read_raw_data_for_training=True,
         num_train_instances=common.NUM_TRAIN_INSTANCES,
         num_test_instances=common.NUM_TEST_INSTANCES):

so transform_data would still be called correctly in that scenario. If a user re-ran the entire script with read_raw_data_for_training=False, intending to continue training with the same pre-processed training and test data, the current implementation would not work as intended. I think read_raw_data_for_training would be better named preprocess_raw_data, since I believe that is what it actually controls here.

To address your questions: yes, this would be the second invocation of main. I think this example could be enhanced to represent the full lifecycle of the process: preprocessing is done ahead of time, training occurs later, and new data arrives after that. The benefit of the change I'm proposing is for anyone who uses this example as a template to build out their own process; it does not disrupt the default behavior, and it gives a marginally better separation between preprocessing and training in that lifecycle. A lot of the code already implies this lifecycle; it just needs to be built out. Here are some enhancements I was thinking about:

  1. Automate fetching the required data (convenience).
  2. I am unsure whether education-num is sparse to begin with; I don't believe it needs special treatment here (simplification).
  3. I am fairly sure it is not sparse at the end, since it has been encoded by that point (simplification).
  4. The shape-related code is too complex and can be simplified because of the two points above, as noted in TODO(b/208879020) (simplification).
  5. Continue training a model that has been saved to disk (enhancement).
  6. Use the preprocessing function that has been stored on disk to process new data using TransformDataset (enhancement; a sketch follows this list).
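
For item 6, here is a hedged sketch of applying a stored transform_fn to new data without re-running analysis. It assumes tft_beam.ReadTransformFn reads what tft_beam.WriteTransformFn previously wrote under working_dir, and that new_data_file and csv_tfxio are placeholders standing in for the new input path and the example's TFXIO CSV decoder:

import tempfile

import apache_beam as beam
import tensorflow_transform.beam as tft_beam

with beam.Pipeline() as pipeline:
  with tft_beam.Context(temp_dir=tempfile.mkdtemp()):
    # Load the transform_fn previously written by WriteTransformFn.
    transform_fn = pipeline | tft_beam.ReadTransformFn(working_dir)
    # Decode the incoming data the same way the example decodes test data.
    new_raw_data = (
        pipeline
        | 'ReadNewData' >> beam.io.ReadFromText(
            new_data_file, coder=beam.coders.BytesCoder())
        | 'DecodeNewData' >> csv_tfxio.BeamSource())
    new_raw_dataset = (new_raw_data, csv_tfxio.TensorAdapterConfig())
    # Apply the stored transform function to the new data; no analysis
    # (statistics computation) happens in this pipeline.
    transformed_new_dataset, _ = (
        (new_raw_dataset, transform_fn)
        | tft_beam.TransformDataset(output_record_batches=True))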

I am pretty new to GitHub and distributed development, so apologies if I'm not structuring my questions and suggestions properly. Thanks!
