Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support tensors for validation_data in fit #8388

Closed
wants to merge 7 commits into from

Conversation

sinjax
Copy link

@sinjax sinjax commented Nov 4, 2017

In our usecase it is convenient to read both our validation data and training data from TFRecords and so as raw tensorflow tensors. Tensors as input is supported through a combination of compile + fit calls. This merge request adds support for Tensors in validation_data

Not sure what conversation this is a part of, there are a few conversations hovering around using TFRecords (e.g. #6928 and #8287)

The first one is quite involved and this approach might well have already been suggested but here's my attempt anyway.

Basic idea is that we check to see whether the validation data is tensors in the training.Model#fit function. If it is, we construct a new Model instance in place and use it to:

  • create a test_function which can be used for validation
  • set the callback_model so a new implementation of TensorBoard (called TensorBoardWithTensorInputs) can read validation data directly as a tensor

The supported use case is:

X, Y = create_training_tensors()
Xv, Yv = create_validation_tensors()

model_input = Input(tensor=X)
model_output = some_keras_model(model_input)
train_model = Model(inputs=model_input, outputs=model_output)

train_model.compile(optimizer="adam",
                    loss='categorical_crossentropy',
                    metrics=['accuracy'],
                    target_tensors=[Y])
train_model.fit(epochs=100,
                steps_per_epoch=2, 
                validation_data=(Xv, Yv), validation_steps=2, callbacks=callbacks)

The code above has the effect of reading the validation tensors twice and outputting the correct metrics etc

While I'm here, an aside:
Upon closer inspection I noticed validation_steps was already specified, but the current behaviour seemed strange. It would seem if your specify validation_steps without setting validation_data the input X and Y in the example above would be read validation_steps times. I suppose one possible use case would be the careful construction of the X and Y tensors such that they contained all training examples followed by validation examples with precise use of steps_per_epoch and validation_steps to make sure the right data is read at the right time. It feels like this would be tricky to pull off at best

@fchollet
Copy link
Collaborator

fchollet commented Nov 8, 2017

There are serious architecture decisions involved here, and as we add this feature then we should add at the same time the ability to pass symbolic tensors to fit. Please create a design doc so we can review what the implementation choices and code structure should be.

The PR itself has a couple issues:

  • many unwarranted style changes (possibly made by an auto-formatter that doesn't follow PEP8 conventions?)
  • private methods should be prefixed with _

@fchollet fchollet closed this Nov 8, 2017
@sinjax
Copy link
Author

sinjax commented Nov 8, 2017

Adding the ability for pass symbolic vectors to fit in general is definitely a desirable feature. Our internal branch of keras contains an implementation which broadly follows the same idea implemented here, namely a check to see if x, y are tensors followed by a internal construction of a new Model.

This might be one way forward, alternatively one could imagine a fit_tensor which would reduce the number of changes needed for fit and might match workflows better anyway. Folks are likely to very purposefully be using symbolic vectors with keras models, so a separate function call might make sense

I'm happy to write up such considerations, can you point me to an existing design document so I know what kind of flow to follow?

@sinjax
Copy link
Author

sinjax commented Nov 8, 2017

As for the PEP8 issues, I turned on PyCharms pep8 checking to produce errors. I then went through fixing anything highlighted, in future changes I'll limit such changes to code I've actually touched

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants