-
Notifications
You must be signed in to change notification settings - Fork 6
/
Copy pathpandas2tf.py
39 lines (27 loc) · 991 Bytes
/
pandas2tf.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
import tensorflow as tf
def train_input_fn(data, batch_size):
    """Build the repeating, batched dataset used for training.

    Args:
        data: A two-item ``(features, labels)`` sequence. ``features`` is
            passed to ``dict()``, so it must be a mapping or something else
            ``dict()`` accepts (presumably a table of feature columns --
            confirm against the caller).
        batch_size: Number of examples per batch, forwarded to
            ``Dataset.batch``.

    Returns:
        The ``tf.data.Dataset`` sliced from ``(dict(features), labels)``,
        repeated (``repeat()`` is called with no count) and then batched.
    """
    feature_columns, targets = data

    # Slice along the first dimension so each element pairs one row of
    # features with its label.
    slices = (dict(feature_columns), targets)
    examples = tf.data.Dataset.from_tensor_slices(slices)

    # Repeat first, then batch -- same order as before, so batches may span
    # the boundary between two passes over the data.
    repeated = examples.repeat()
    return repeated.batch(batch_size)
def eval_input_fn(data, batch_size):
    """Build a batched dataset for evaluation or prediction.

    Args:
        data: Either a two-item ``(features, labels)`` sequence (evaluation)
            or the features on their own (prediction). In both cases the
            features are passed to ``dict()``, so they must be a mapping or
            something else ``dict()`` accepts.
        batch_size: Number of examples per batch; must not be None.

    Returns:
        The ``tf.data.Dataset`` sliced from ``(dict(features), labels)`` when
        labels are present, or from ``dict(data)`` otherwise, then batched.
        The dataset is not repeated.

    Raises:
        AssertionError: If ``batch_size`` is None.
    """
    # Fail fast, before any tensors are built. The check is raised explicitly
    # (not via ``assert``) so it is still enforced under ``python -O``;
    # AssertionError is kept so existing callers see the same exception type.
    if batch_size is None:
        raise AssertionError("batch_size must not be None")

    try:
        # With labels present this is an evaluation input.
        features, labels = data
        # dict() stays inside the try on purpose: a label-less input that
        # happens to iterate as exactly two items (e.g. two column names)
        # unpacks successfully and is only rejected here.
        inputs = (dict(features), labels)
    except (TypeError, ValueError):
        # Not a (features, labels) pair --> prediction. Only the errors that
        # unpacking and dict() raise for a wrong shape are handled; anything
        # else (KeyboardInterrupt, genuine bugs) propagates.
        inputs = dict(data)

    # Convert the inputs to a Dataset and batch the examples.
    return tf.data.Dataset.from_tensor_slices(inputs).batch(batch_size)